Variational Diffusion Posterior Sampling with Midpoint Guidance

Badr Moufad, Yazid Janati, Lisa Bedin, Alain Durmus, Randal Douc, Eric Moulines, Jimmy Olsson

TL;DR

This paper tackles posterior sampling in Bayesian inverse problems with diffusion-model priors, where the scores of the denoising posterior distributions contain an intractable guidance term. It introduces Midpoint Guidance Posterior Sampling (MGPS), which uses a midpoint decomposition of the backward transition and a Gaussian variational backward kernel to approximate the guidance term, trading off the complexity of the guidance term against that of the prior transitions. MGPS performs well empirically across Gaussian mixtures, linear and nonlinear image tasks, latent diffusion priors, and ECG reconstruction, outperforming several state-of-the-art posterior samplers while maintaining competitive runtimes. The authors also provide publicly available code to reproduce the experiments.

Abstract

Diffusion models have recently shown considerable potential in solving Bayesian inverse problems when used as priors. However, sampling from the resulting denoising posterior distributions remains a challenge as it involves intractable terms. To tackle this issue, state-of-the-art approaches formulate the problem as that of sampling from a surrogate diffusion model targeting the posterior and decompose its scores into two terms: the prior score and an intractable guidance term. While the former is replaced by the pre-trained score of the considered diffusion model, the guidance term has to be estimated. In this paper, we propose a novel approach that utilises a decomposition of the transitions which, in contrast to previous methods, allows a trade-off between the complexity of the intractable guidance term and that of the prior transitions. We validate the proposed approach through extensive experiments on linear and nonlinear inverse problems, including challenging cases with latent diffusion models as priors. We then demonstrate its applicability to various modalities and its promising impact on public health by tackling cardiovascular disease diagnosis through the reconstruction of incomplete electrocardiograms. The code is publicly available at https://github.com/yazidjanati/mgps.
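
The score decomposition mentioned in the abstract can be written out as follows. This is only a sketch in assumed notation (not taken from the paper): $p_t$ is the prior marginal at noise level $t$, $g_0(y \mid x_0)$ the likelihood of the observation $y$, $p_{0|t}$ the prior denoising distribution, and $\pi_t(\cdot \mid y)$ the corresponding posterior marginal.

$$
\pi_t(x_t \mid y) \propto p_t(x_t)\, g_t(y \mid x_t), \qquad g_t(y \mid x_t) = \int g_0(y \mid x_0)\, p_{0|t}(x_0 \mid x_t)\, \mathrm{d}x_0
$$

$$
\nabla_{x_t} \log \pi_t(x_t \mid y) = \nabla_{x_t} \log p_t(x_t) + \nabla_{x_t} \log g_t(y \mid x_t)
$$

The first term on the right is the prior score, which the pre-trained model approximates. The second is the guidance term, which is intractable in general because $g_t$ integrates the likelihood over the denoising distribution $p_{0|t}$.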

Paper Structure

This paper contains 51 sections, 1 theorem, 67 equations, 27 figures, 16 tables, 3 algorithms.

Key Result

Lemma 1

For all $k \in \llbracket 1, n-1 \rrbracket$ and ${\ell_{k}} \in \llbracket 0, k \rrbracket$ it holds that

Figures (27)

  • Figure 1: For each color, the different solid arrows indicate different conditional densities that need to be approximated for a given choice of the midpoint ${\ell_{k}}$. The longer the arrow, the more difficult it is to approximate the corresponding conditional density. By placing the ${\ell_{k}}$ midway between zero and $k$, the shortest arrows are obtained.
  • Figure 2: Left: average $W_2$ with $10\%$--$90\%$ quantile range. Right: distribution of the minimizing $\eta^*$.
  • Figure 3: MGPS sample images for half mask (left), expand task, Gaussian blur and motion blur (right) on the ImageNet dataset.
  • Figure 4: Left: SW as a function of $\eta$ with ${\ell_{k}} = \lfloor \eta k \rfloor$. Right: SW as a function of the number of gradient steps, for a specific choice of $({\ell_{k}})_k$.
  • Figure 5: MGPS samples with LDM on FFHQ dataset.
  • ...and 22 more figures
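
The midpoint rule ${\ell_{k}} = \lfloor \eta k \rfloor$ from the Figure 4 caption can be illustrated with a short Python sketch. The function name `midpoint_schedule` and its arguments are hypothetical and are not taken from the authors' code; the sketch only assumes $\eta \in [0, 1]$, so that each ${\ell_{k}}$ falls in the range $\llbracket 0, k \rrbracket$ stated in Lemma 1.

```python
import math


def midpoint_schedule(n: int, eta: float) -> dict[int, int]:
    """Map each step k = 1, ..., n - 1 to a midpoint l_k = floor(eta * k).

    Hypothetical helper for illustration only; `n` is the number of
    diffusion steps and `eta` the midpoint fraction (assumed in [0, 1]).
    """
    if n < 2:
        raise ValueError("n must be at least 2")
    if not 0.0 <= eta <= 1.0:
        raise ValueError("eta must lie in [0, 1]")
    # floor(eta * k) always lies between 0 and k when eta is in [0, 1].
    return {k: math.floor(eta * k) for k in range(1, n)}


if __name__ == "__main__":
    # eta = 0.5 places each midpoint roughly halfway between 0 and k.
    for k, l_k in midpoint_schedule(9, 0.5).items():
        print(f"k={k}  l_k={l_k}")
```

With $\eta = 0$ every midpoint is $0$ and with $\eta = 1$ every midpoint equals $k$; intermediate values interpolate between these two extremes.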

Theorems & Definitions (4)

  • Lemma 1
  • Example 2
  • Remark 3
  • Proof of Lemma 1