Tackling the Problem of Distributional Shifts: Correcting Misspecified, High-Dimensional Data-Driven Priors for Inverse Problems

Gabriel Missael Barco, Alexandre Adam, Connor Stone, Yashar Hezaveh, Laurence Perreault-Levasseur

TL;DR

The paper tackles distributional shifts in data-driven priors for Bayesian inverse problems by proposing an iterative, posterior-sample–driven updating scheme for score-based priors. It employs score-based diffusion models to encode flexible population priors and uses a Gaussian-perturbed linear forward model, as in strong gravitational lensing, to generate and refine posterior samples. Empirical results on MNIST and galaxy imaging show that starting from misspecified priors, the updates converge toward the true population distribution and yield less biased posterior reconstructions, with metrics like log-likelihood of residuals and PQMass supporting convergence. This approach enables more reliable inverse problem solutions under distribution shifts, with potential impact for large sky surveys such as LSST and Euclid.

Abstract

Bayesian inference for inverse problems hinges critically on the choice of priors. In the absence of specific prior information, population-level distributions can serve as effective priors for parameters of interest. With the advent of machine learning, the use of data-driven population-level distributions (encoded, e.g., in a trained deep neural network) as priors is emerging as an appealing alternative to simple parametric priors in a variety of inverse problems. However, in many astrophysical applications, it is often difficult or even impossible to acquire independent and identically distributed samples from the underlying data-generating process of interest to train these models. In these cases, corrupted data or a surrogate, e.g. a simulator, is often used to produce training samples, meaning that there is a risk of obtaining misspecified priors. This, in turn, can bias the inferred posteriors in ways that are difficult to quantify, which limits the potential applicability of these models in real-world scenarios. In this work, we propose addressing this issue by iteratively updating the population-level distributions by retraining the model with posterior samples from different sets of observations, and we showcase the potential of this method on the problem of background image reconstruction in strong gravitational lensing when score-based models are used as data-driven priors. We show that, starting from a misspecified prior distribution, the updated distribution becomes progressively closer to the underlying population-level distribution, and the resulting posterior samples exhibit reduced bias after several updates.
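
The update loop described in the abstract can be illustrated on a toy problem. The sketch below is not the paper's implementation: it swaps the score-based prior for a one-dimensional Gaussian prior $\mathcal{N}(\mu, \tau^2)$ so that posterior sampling is exact, and it "retrains" the prior by refitting $\mu$ and $\tau$ to one posterior sample per observation. The function names `posterior_samples` and `update_prior`, the scalar forward model `a`, and all numerical values are assumptions chosen for illustration.

```python
import numpy as np


def posterior_samples(y, mu, tau, a, sigma, rng):
    """Draw one exact posterior sample of x for each observation y = a*x + eta.

    Assumes a Gaussian prior N(mu, tau^2) and Gaussian noise N(0, sigma^2),
    so the posterior is Gaussian in closed form (a stand-in for the
    score-based posterior sampling used in the paper).
    """
    precision = 1.0 / tau**2 + a**2 / sigma**2
    mean = (mu / tau**2 + a * y / sigma**2) / precision
    return mean + rng.standard_normal(y.shape) / np.sqrt(precision)


def update_prior(y, mu, tau, a, sigma, rng, n_updates):
    """Iteratively refit the prior to posterior samples of all observations."""
    history = [(float(mu), float(tau))]
    for _ in range(n_updates):
        x_post = posterior_samples(y, mu, tau, a, sigma, rng)
        # "Retraining" here is just moment matching on the posterior samples.
        mu, tau = float(x_post.mean()), float(x_post.std())
        history.append((mu, tau))
    return history


rng = np.random.default_rng(0)

# Hypothetical true population and observation model (illustrative values).
mu_true, tau_true = 2.0, 1.0
a, sigma = 1.0, 0.5
x_true = mu_true + tau_true * rng.standard_normal(5000)
y = a * x_true + sigma * rng.standard_normal(5000)

# Start from a deliberately misspecified prior.
history = update_prior(y, mu=-1.0, tau=0.5, a=a, sigma=sigma, rng=rng, n_updates=10)

for k, (m, t) in enumerate(history):
    print(f"update {k}: mu = {m:.3f}, tau = {t:.3f}")
```

In this conjugate setting each update moves the prior mean and width toward the population values, which mirrors the qualitative behaviour the abstract reports. It says nothing about convergence for high-dimensional score-based priors, where posterior sampling is only approximate.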

Paper Structure (23 sections, 20 equations, 15 figures, 1 algorithm)

Figures (15)

  • Figure 1: Impact of having a misspecified prior for posterior sampling in strong gravitational lensing source reconstruction. The first panel ($\mathbf{x}^\star$) displays the true source spiral galaxy, followed by the observed lensed image $\mathbf{y}$, which has Gaussian additive noise. Subsequent panels illustrate reconstructions using $p_e(\mathbf{x})$ (trained on elliptical galaxies, out-of-distribution) and $p_s(\mathbf{x})$ (trained on spiral galaxies, in-distribution). The reconstruction with $p_e(\mathbf{x})$ is biased due to the mismatch between the training data distribution and the distribution the observation was generated from, demonstrating the critical need for correct prior selection to ensure accurate source reconstructions.
  • Figure 2: Graphical model of the inference problem. The true prior distribution, parametrized by the population-level parameters $\theta^\star$, generates the realizations of the unobserved parameters of interest $\mathbf{x}_i$. Our goal is to learn an estimate $\hat{\theta} \approx \theta^\star$. In this work, we have access to the noise distribution that generates $\boldsymbol{\eta}_i$, the forward model $A$, and a set of $N$ observations $\{\mathbf{y}_i\}_{i=1}^N$.
  • Figure 3: This figure demonstrates, with prior samples, the effectiveness of model updates in learning and forgetting specific digits in the MNIST experiment under observational noise of $\sigma_{\boldsymbol{\eta}} = 0.5$. Left: Samples of initial model $p_{\theta_0}$. These samples include the digit $6$ and lack the digit $4$, in agreement with the initial model's training distribution. Right: Samples from the final prior $p_{\theta_4}$. In this model, the digit $6$ is absent, and digits resembling $4$ are present, showcasing successful adaptation to the target distribution. However, sample quality is variable, as seen in the top left sample, where it is ambiguous whether it represents a $1$ or a $7$.
  • Figure 4: Learning and forgetting dynamics across updates using Algorithm 1 towards the target distribution in the MNIST experiment with observational noise of $\sigma_{\boldsymbol{\eta}} = 0.4$ during the updates. The plot shows the classification of $2\,048$ prior samples $\mathbf{x} \sim p_{\theta_\alpha}(\mathbf{x})$ at each update by a CNN classifier. Each panel corresponds to a digit category, with proportions of samples being shown for each update. Initially, the prior was trained excluding digits $4$ and $1$, while the target distribution excluded $6$ and $1$. The objective was to forget digit $6$, which was accomplished in a single update, and to learn digit $4$, which took several updates, initially producing classifications similar to digit $9$. The red dashed line represents the proportion from the target distribution, serving as a benchmark.
  • Figure 5: Sequences showcasing the model's ability to accurately reconstruct the digit $4$ from noisy observations in posterior samples after updating the prior with Algorithm 1, despite $4$ not being included in the initial prior. The evolution of posterior samples $\mathbf{x} \sim p_{\theta_\alpha}(\mathbf{x} \mid \mathbf{y})$ is shown for each update, highlighting the model's improvement. Starting from the observed image $\mathbf{y}$ (leftmost), the samples progress from $\theta_0$ to $\theta_4$ (left to right). Initially, the model confuses $4$ with $9$ under $\theta_0$ and $\theta_1$. By $\theta_2$ and, more clearly, by $\theta_4$, the posterior accurately reflects $4$, closely matching the true digit $\mathbf{x}^\star$ (rightmost). This illustrates the effectiveness of the model in adapting its posterior estimates over sequential updates to learn and correct its representation of digit $4$ from noisy data.
  • ...and 10 more figures

Theorems & Definitions (1)

  • Definition 2.1