Variational Diffusion Posterior Sampling with Midpoint Guidance
Badr Moufad, Yazid Janati, Lisa Bedin, Alain Durmus, Randal Douc, Eric Moulines, Jimmy Olsson
TL;DR
This paper tackles posterior sampling in Bayesian inverse problems with diffusion-model priors, where the score of the posterior denoising process contains an intractable guidance term. It introduces Midpoint Guidance Posterior Sampling (MGPS), which uses a midpoint decomposition of the backward transitions and a Gaussian variational backward kernel to approximate the guidance term, allowing a trade-off between the complexity of the guidance term and that of the prior transitions. MGPS performs strongly across Gaussian mixtures, linear and nonlinear image tasks, latent diffusion priors, and ECG reconstruction, outperforming several state-of-the-art posterior samplers while keeping runtimes competitive. The authors also provide publicly available code to reproduce the experiments.
Abstract
Diffusion models have recently shown considerable potential in solving Bayesian inverse problems when used as priors. However, sampling from the resulting denoising posterior distributions remains a challenge because it involves intractable terms. To tackle this issue, state-of-the-art approaches formulate the problem as that of sampling from a surrogate diffusion model targeting the posterior and decompose its scores into two terms: the prior score and an intractable guidance term. While the former is replaced by the pre-trained score of the considered diffusion model, the guidance term has to be estimated. In this paper, we propose a novel approach that utilises a decomposition of the transitions which, in contrast to previous methods, allows a trade-off between the complexity of the intractable guidance term and that of the prior transitions. We validate the proposed approach through extensive experiments on linear and nonlinear inverse problems, including challenging cases with latent diffusion models as priors. We then demonstrate its applicability to various modalities and its promising impact on public health by tackling cardiovascular disease diagnosis through the reconstruction of incomplete electrocardiograms. The code is publicly available at https://github.com/yazidjanati/mgps.
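The score decomposition mentioned in the abstract can be written as grad log p_t(x_t | y) = grad log p_t(x_t) + grad log p(y | x_t), where the guidance likelihood p(y | x_t) = ∫ p(y | x_0) p_{0|t}(x_0 | x_t) dx_0 is intractable for a general diffusion prior. The sketch below is a toy illustration of that identity only, not MGPS itself. It assumes a hypothetical 1D setting in which everything is Gaussian and therefore tractable: prior x_0 ~ N(0, 1), variance-preserving noising x_t = sqrt(alpha_bar) x_0 + sqrt(1 - alpha_bar) noise, and observation y = x_0 + sigma_y noise. The function names `gauss_logpdf`, `grad`, and `scores` are illustrative.

```python
import math


def gauss_logpdf(z, mean, var):
    """Log density of N(mean, var) evaluated at z."""
    return -0.5 * math.log(2 * math.pi * var) - (z - mean) ** 2 / (2 * var)


def grad(f, x, h=1e-5):
    """Central finite-difference derivative of a scalar function f at x."""
    return (f(x + h) - f(x - h)) / (2 * h)


def scores(x_t, y, alpha_bar, sigma_y):
    """Return (prior score, guidance score, posterior score) at x_t.

    Toy assumptions (hypothetical, for illustration only):
      x_0 ~ N(0, 1), x_t | x_0 ~ N(a x_0, 1 - a^2) with a = sqrt(alpha_bar),
      y | x_0 ~ N(x_0, sigma_y^2).
    """
    a = math.sqrt(alpha_bar)
    s2 = sigma_y ** 2
    # Prior marginal p_t(x_t) = N(0, 1), since a^2 + (1 - a^2) = 1.
    log_prior = lambda x: gauss_logpdf(x, 0.0, 1.0)
    # Guidance likelihood: p_{0|t}(x_0 | x_t) = N(a x_t, 1 - a^2), so
    # integrating out x_0 gives p(y | x_t) = N(a x_t, 1 - a^2 + s2).
    log_guidance = lambda x: gauss_logpdf(y, a * x, 1.0 - a * a + s2)
    # Exact posterior p_t(x_t | y) from the joint Gaussian of (x_t, y).
    log_post = lambda x: gauss_logpdf(
        x, a * y / (1.0 + s2), 1.0 - a * a / (1.0 + s2)
    )
    return grad(log_prior, x_t), grad(log_guidance, x_t), grad(log_post, x_t)


if __name__ == "__main__":
    prior, guidance, post = scores(x_t=0.3, y=1.2, alpha_bar=0.5, sigma_y=0.4)
    # Prior score + guidance score should match the posterior score
    # up to floating-point error.
    print(prior + guidance, post)
```

Finite differences are used instead of hand-derived gradients so that the check relies only on the three log-densities and Bayes' rule. In the setting the paper targets, p_{0|t} is not Gaussian and the guidance score has no closed form; that is the term MGPS approximates, using a Gaussian variational backward kernel according to the TL;DR above.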
