
Learning Diffusion Priors from Observations by Expectation Maximization

François Rozet, Gérôme Andry, François Lanusse, Gilles Louppe

TL;DR

This work addresses learning high-quality diffusion priors when only noisy, incomplete observations are available. It casts diffusion-prior training as an empirical Bayes problem solved by an Expectation-Maximization framework (DiEM), and it introduces Moment Matching Posterior Sampling (MMPS) to produce accurate posterior samples without destabilizing the diffusion process. The method is validated on a low-dimensional manifold, corrupted CIFAR-10, and accelerated MRI, showing that the learned diffusion priors yield plausible, data-consistent reconstructions and can handle substantial information loss. The approach enables diffusion priors to be used reliably in scientific inverse problems where clean, large datasets are scarce, potentially broadening the applicability of diffusion-based Bayesian inference.
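The empirical Bayes loop described above alternates an E-step, which infers the latent clean data behind each observation under the current prior, and an M-step, which refits the prior to those posteriors. Below is a minimal sketch under a strong simplifying assumption. A Gaussian prior $\mathcal{N}(\mu, \Sigma)$ stands in for the diffusion model, and observations are linear-Gaussian, $y = Ax + e$. Under these choices the E-step posterior is exact and the M-step only refits a mean and covariance. In DiEM, the E-step instead draws posterior samples with MMPS and the M-step trains a diffusion model on them. The function `em_gaussian_prior` and the synthetic data are hypothetical and purely illustrative.

```python
import numpy as np

def em_gaussian_prior(ys, As, noise_var, dim, iters=50):
    """Fit a Gaussian prior N(mu, Sigma) from observations y_i = A_i x_i + e_i,
    e_i ~ N(0, noise_var * I), by expectation-maximization (illustrative stand-in)."""
    mu, Sigma = np.zeros(dim), np.eye(dim)
    for _ in range(iters):
        means, covs = [], []
        for y, A in zip(ys, As):
            # E-step: exact Gaussian posterior p(x_i | y_i) under the current prior
            S = A @ Sigma @ A.T + noise_var * np.eye(len(y))
            K = np.linalg.solve(S, A @ Sigma).T  # gain Sigma A^T S^{-1}
            means.append(mu + K @ (y - A @ mu))
            covs.append(Sigma - K @ A @ Sigma)
        means = np.array(means)
        # M-step: mean and covariance maximizing the expected complete-data log-likelihood
        mu = means.mean(axis=0)
        diffs = means - mu
        Sigma = np.mean(covs, axis=0) + diffs.T @ diffs / len(means)
    return mu, Sigma

# Synthetic data: each observation keeps 2 of 3 coordinates, plus small noise.
rng = np.random.default_rng(0)
mu_true = np.array([1.0, -2.0, 0.5])
L = np.array([[1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [-0.3, 0.2, 0.7]])
Sigma_true = L @ L.T
ys, As = [], []
for x in rng.multivariate_normal(mu_true, Sigma_true, size=2000):
    A = np.eye(3)[rng.choice(3, size=2, replace=False)]
    As.append(A)
    ys.append(A @ x + 0.1 * rng.standard_normal(2))

mu_hat, Sigma_hat = em_gaussian_prior(ys, As, noise_var=0.01, dim=3)
print(np.round(mu_hat, 2))
print(np.round(Sigma_hat, 2))
```

No single observation reveals the full 3-d distribution. The covariance is still recoverable because every pair of coordinates is observed jointly in some observations. This is a much simpler analogue of how EM can recover a prior from incomplete data.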

Abstract

Diffusion models recently proved to be remarkable priors for Bayesian inverse problems. However, training these models typically requires access to large amounts of clean data, which could prove difficult in some settings. In this work, we present DiEM, a novel method based on the expectation-maximization algorithm for training diffusion models from incomplete and noisy observations only. Unlike previous works, DiEM leads to proper diffusion models, which is crucial for downstream tasks. As part of our methods, we propose and motivate an improved posterior sampling scheme for unconditional diffusion models. We present empirical evidence supporting the effectiveness of our approach.

Paper Structure

This paper contains 32 sections, 1 theorem, 21 equations, 20 figures, 5 tables, 4 algorithms.

Key Result

Theorem 1

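For any distribution $p(x)$ and $p(x_t \mid x) = \mathcal{N}(x_t \mid x, \Sigma_t)$, the first and second moments of the distribution $p(x \mid x_t)$ are linked to the score function $\nabla_{\!{x_t}} \log p(x_t)$ through

$$\mathbb{E}[x \mid x_t] = x_t + \Sigma_t \nabla_{\!{x_t}} \log p(x_t)$$

$$\mathbb{V}[x \mid x_t] = \Sigma_t \nabla_{\!{x_t}}^2 \log p(x_t) \, \Sigma_t + \Sigma_t$$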
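In the scalar case $\Sigma_t = \sigma^2$, Theorem 1 (Tweedie's formula) reads $\mathbb{E}[x \mid x_t] = x_t + \sigma^2 \, \partial_{x_t} \log p(x_t)$ and $\mathbb{V}[x \mid x_t] = \sigma^4 \, \partial^2_{x_t} \log p(x_t) + \sigma^2$. The sketch below checks both identities numerically for a hypothetical two-component Gaussian mixture prior, whose noisy marginal $p(x_t)$ has a closed form. The left-hand sides come from brute-force integration of $p(x \mid x_t)$ on a grid. The right-hand sides come from finite differences of $\log p(x_t)$. The mixture parameters, step size and grid are arbitrary choices.

```python
import numpy as np

# Hypothetical 1-d prior: p(x) = sum_k W[k] * N(x | M[k], S[k]^2)
W = np.array([0.3, 0.7])
M = np.array([-2.0, 1.5])
S = np.array([0.5, 1.0])

def normal_pdf(z, mean, var):
    return np.exp(-0.5 * (z - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

def log_marginal(xt, sigma2):
    # Convolving each component with N(0, sigma2) gives p(x_t) in closed form
    return np.log(np.sum(W * normal_pdf(xt, M, S**2 + sigma2)))

def tweedie_moments(xt, sigma2, h=1e-3):
    # Right-hand side of Theorem 1, derivatives by central finite differences
    f_minus, f0, f_plus = (log_marginal(xt + d, sigma2) for d in (-h, 0.0, h))
    score = (f_plus - f_minus) / (2 * h)
    hess = (f_plus - 2 * f0 + f_minus) / h**2
    return xt + sigma2 * score, sigma2**2 * hess + sigma2

def grid_moments(xt, sigma2):
    # Left-hand side: moments of p(x | x_t), proportional to p(x) N(x_t | x, sigma2)
    x = np.linspace(-15.0, 15.0, 200_001)
    prior = np.sum(W[:, None] * normal_pdf(x[None, :], M[:, None], S[:, None] ** 2), axis=0)
    post = prior * normal_pdf(xt, x, sigma2)
    post /= post.sum()
    mean = np.sum(x * post)
    return mean, np.sum((x - mean) ** 2 * post)

for xt in (-3.0, 0.0, 2.0):
    print(xt, tweedie_moments(xt, 1.0), grid_moments(xt, 1.0))
```

In a trained diffusion model, the score would come from the denoiser network, and the second-order term would require its Jacobian. This toy check only confirms the identity itself.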

Figures (20)

  • Figure 1: Illustration of the posterior $q(x_t \mid y)$ for the Gaussian approximation $q(x \mid x_t)$ when the prior $p(x)$ lies on a manifold. Ellipses represent 95% credible regions of $q(x \mid x_t)$. (A) With $\Sigma_t$ as heuristic for $\mathbb{V}[x \mid x_t]$, any $x_t$ whose mean $\mathbb{E}[x \mid x_t]$ is close to the plane $y = Ax$ is considered likely. (B) With $\mathbb{V}[x \mid x_t]$, more regions are correctly pruned. (C) Ground-truth $p(x_t \mid y)$ and $p(x \mid x_t)$ for reference.
  • Figure 1: Evaluation of final models trained on corrupted CIFAR-10. DiEM outperforms AmbientDiffusion Daras2023Ambient at similar corruption levels $\rho$. Using heuristics for $\mathbb{V}[x \mid x_t]$ instead of Tweedie's formula greatly decreases the sample quality.
  • Figure 2: Sinkhorn divergence Chizat2020Faster between the posteriors $p(x_t \mid y)$ and $q(x_t \mid y)$ for different heuristics of $\mathbb{V}[x \mid x_t]$ when the prior $p(x)$ lies on 1-d manifolds embedded in $\mathbb{R}^3$. Lines and shaded areas represent the 25th, 50th and 75th percentiles over 64 randomly generated manifolds Zenke2021Remarkable and measurement matrices $A \in \mathbb{R}^{1 \times 3}$. Using $\mathbb{V}[x \mid x_t]$ instead of heuristics yields posteriors $q(x_t \mid y)$ that are orders of magnitude more accurate.
  • Figure 3: Illustration of 2-d marginals of the model $q_{\theta_k}(x)$ along the EM iterations. The initial Gaussian prior $q_0(x)$ leads to a very dispersed first model $q_{\theta_1}(x)$. The EM algorithm gradually prunes the density regions which are inconsistent with observations, until it reaches a stationary distribution. The marginals of the final distribution are close to the marginals of the ground-truth distribution.
  • Figure 4: FID of $q_{\theta_k}(x)$ along the EM iterations for the corrupted CIFAR-10 experiment.
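Figures 1 and 2 contrast two Gaussian approximations of the likelihood $p(y \mid x_t)$ for a linear measurement $y = Ax + e$, $e \sim \mathcal{N}(0, \Sigma_y)$. One plugs in the heuristic covariance $\Sigma_t$. The other uses $\mathbb{V}[x \mid x_t]$ from Tweedie's formula, giving $q(y \mid x_t) = \mathcal{N}(y \mid A\,\mathbb{E}[x \mid x_t], \Sigma_y + A\,\mathbb{V}[x \mid x_t] A^\top)$. This form is our reading of MMPS, not the authors' code. The sketch assumes a hypothetical Gaussian prior, for which the moment-matched likelihood is exact. It compares that likelihood with direct conditioning of the joint Gaussian of $(x_t, y)$. The analytic score of the Gaussian marginal stands in for a trained score network, and all helper names are illustrative.

```python
import numpy as np

def gaussian_logpdf(y, mean, cov):
    """Log-density of N(y | mean, cov)."""
    d = y - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (d @ np.linalg.solve(cov, d) + logdet + len(y) * np.log(2 * np.pi))

def tweedie_from_gaussian_score(xt, mu, Sigma_x, Sigma_t):
    """Theorem 1 applied to an assumed Gaussian prior N(mu, Sigma_x),
    whose noisy marginal is p(x_t) = N(mu, Sigma_x + Sigma_t)."""
    P = np.linalg.inv(Sigma_x + Sigma_t)
    score = -P @ (xt - mu)  # gradient of log p(x_t)
    hess = -P               # Hessian of log p(x_t)
    mean = xt + Sigma_t @ score
    cov = Sigma_t @ hess @ Sigma_t + Sigma_t
    return mean, cov

def log_q_y(y, A, Sigma_y, mean, cov):
    """Moment-matched likelihood N(y | A mean, Sigma_y + A cov A^T)."""
    return gaussian_logpdf(y, A @ mean, Sigma_y + A @ cov @ A.T)

def exact_log_p_y(y, xt, A, Sigma_y, mu, Sigma_x, Sigma_t):
    """Reference: condition the joint Gaussian of (x_t, y) on x_t."""
    S_yt = A @ Sigma_x
    gain = S_yt @ np.linalg.inv(Sigma_x + Sigma_t)
    S_yy = A @ Sigma_x @ A.T + Sigma_y
    return gaussian_logpdf(y, A @ mu + gain @ (xt - mu), S_yy - gain @ S_yt.T)

rng = np.random.default_rng(1)
mu = np.zeros(3)
B = rng.standard_normal((3, 3))
Sigma_x = B @ B.T + 0.1 * np.eye(3)
Sigma_t = 0.5 * np.eye(3)
A = rng.standard_normal((1, 3))  # A in R^{1x3}, as in Figure 2
Sigma_y = 0.01 * np.eye(1)
xt, y = rng.standard_normal(3), rng.standard_normal(1)

mean, cov = tweedie_from_gaussian_score(xt, mu, Sigma_x, Sigma_t)
exact = exact_log_p_y(y, xt, A, Sigma_y, mu, Sigma_x, Sigma_t)
moment = log_q_y(y, A, Sigma_y, mean, cov)         # uses V[x | x_t]
heuristic = log_q_y(y, A, Sigma_y, mean, Sigma_t)  # uses Sigma_t instead
print(exact, moment, heuristic)
```

With a Gaussian prior, $\mathbb{V}[x \mid x_t] = \Sigma_t - \Sigma_t (\Sigma_x + \Sigma_t)^{-1} \Sigma_t \preceq \Sigma_t$. The $\Sigma_t$ heuristic therefore overstates the uncertainty of $Ax$ and makes the likelihood too permissive. This is consistent with the behaviour shown in panel (A) of Figure 1.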

Theorems & Definitions (3)

  • Theorem 1
  • proof
  • proof