An Expectation-Maximization Algorithm for Training Clean Diffusion Models from Corrupted Observations

Weimin Bai, Yifei Wang, Wenzheng Chen, He Sun

TL;DR

EMDiffusion is an EM framework for training diffusion models from corrupted observations. It alternates between reconstructing clean images with the current diffusion prior (E-step) and updating that prior on the reconstructions (M-step). The prior is initialized from a small set of clean images; diffusion posterior sampling with an adaptive scaling factor then prevents mode collapse while the score-based prior is progressively refined. On CIFAR-10 and CelebA, EMDiffusion achieves state-of-the-art or competitive results in inpainting, denoising, and deblurring, often matching or approaching the performance of methods that rely on clean priors, while training on corrupted data apart from the small clean initialization set. This makes learned diffusion priors practical in settings where large clean datasets are unavailable, such as real-world computational imaging tasks.

Abstract

Diffusion models excel in solving imaging inverse problems due to their ability to model complex image priors. However, their reliance on large, clean datasets for training limits their practical use where clean data is scarce. In this paper, we propose EMDiffusion, an expectation-maximization (EM) approach to train diffusion models from corrupted observations. Our method alternates between reconstructing clean images from corrupted data using a known diffusion model (E-step) and refining diffusion model weights based on these reconstructions (M-step). This iterative process leads the learned diffusion model to gradually converge to the true clean data distribution. We validate our method through extensive experiments on diverse computational imaging tasks, including random inpainting, denoising, and deblurring, achieving new state-of-the-art performance.
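
The alternation described in the abstract can be illustrated with a toy sketch. This is not the paper's implementation: it assumes a one-dimensional Gaussian prior in place of the diffusion model, so that posterior sampling has a closed form, and it assumes the corruption is additive Gaussian noise with known variance. The function name `em_gaussian_prior` and all of its parameters are hypothetical.

```python
import random
import statistics


def em_gaussian_prior(observations, noise_var, init_mean, init_var,
                      n_iters=30, seed=0):
    """Stochastic EM for a 1-D Gaussian prior from observations y = x + n.

    Toy stand-in for EMDiffusion: the Gaussian (mean, var) plays the role
    of the diffusion prior; n ~ N(0, noise_var) is the known corruption.
    """
    rng = random.Random(seed)
    mean, var = init_mean, init_var
    for _ in range(n_iters):
        # E-step: with the current prior held fixed, draw one posterior
        # sample of the clean value x for every corrupted observation y.
        post_var = 1.0 / (1.0 / var + 1.0 / noise_var)
        post_std = post_var ** 0.5
        samples = [
            rng.gauss(post_var * (mean / var + y / noise_var), post_std)
            for y in observations
        ]
        # M-step: refit the prior parameters to the posterior samples.
        mean = statistics.fmean(samples)
        var = statistics.pvariance(samples, mu=mean)
    return mean, var


if __name__ == "__main__":
    data_rng = random.Random(1)
    clean = [data_rng.gauss(2.0, 1.0) for _ in range(2000)]
    noisy = [x + data_rng.gauss(0.0, 0.5) for x in clean]
    # Start from a deliberately poor prior, as if fit on very little data.
    print(em_gaussian_prior(noisy, noise_var=0.25, init_mean=0.0, init_var=0.5))
```

In this toy setting the estimates move from the poor initial prior toward the mean and variance of the clean data, although only noisy observations enter the loop. A diffusion model has no closed-form posterior, which is why the method described above relies on diffusion posterior sampling in the E-step.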

Paper Structure (36 sections, 13 equations, 8 figures, 4 tables, 1 algorithm)

Figures (8)

  • Figure 1: Overview of EMDiffusion. The paper proposes an expectation-maximization (EM) approach to jointly solve imaging inverse problems and train a diffusion model from corrupted observations. Left: In each E-step, we assume a known diffusion model and perform posterior sampling to reconstruct images from corrupted observations. In the M-step, we update the weights of the diffusion model based on these posterior samples. By iteratively alternating between these two steps, the diffusion model gradually learns the clean image distribution and generates high-quality posterior samples. Right: Raw observations and reconstructed clean images based on the diffusion model learned from corrupted data.
  • Figure 2: Adaptive diffusion posterior sampling on CIFAR-10 inpainting. (a) Corrupted observations from the test set, with 60% of the pixels masked in each image. (b), (c), and (d) Diffusion posterior samples with the diffusion prior weighted by different scaling factors: $\lambda=1, 10, 20$. The diffusion prior is pre-trained using the 50 clean images shown in (e). When $\lambda$ is small, there is obvious mode collapse, and all posterior samples come from the training set of 50 clean images, unrelated to the observations. As $\lambda$ increases, the data likelihood gains more significance, resulting in reconstructed images that are more consistent with the inpainting observations.
  • Figure 3: Results on CIFAR-10 inpainting. In each image, 60% of the pixels are masked. As the EM iterations progress, the diffusion model learns cleaner prior distributions, improving the quality of posterior samples. Our method significantly outperforms the baselines, SURE-Score and AmbientDiffusion, achieving reconstruction quality comparable to DPS with a clean prior.
  • Figure 4: Results on (a) CIFAR-10 denoising and (b) CelebA deblurring. Our method significantly outperforms the baselines, SURE-Score and AmbientDiffusion.
  • Figure 5: Ablation studies. (a) PSNR of diffusion posterior samples generated by the initial diffusion models trained on different amounts (10, 50, 100, 500) or types (in-distribution or out-of-distribution) of clean data. (b) FID scores of learned diffusion models after each EM iteration. The diffusion model trained on 50,000 corrupted images achieves performance similar to that of models trained on 15,000-20,000 clean images. (c) PSNR of diffusion posterior samples weighted by different scaling factors $\lambda$ at each stage. The optimal $\lambda$ for posterior sampling decreases as the EM iterations progress.
  • ...and 3 more figures
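
The effect of the scaling factor $\lambda$ described in the Figure 2 caption can be sketched in a toy Gaussian setting. The sketch assumes that $\lambda$ multiplies the log-likelihood term of the posterior, which is consistent with the caption's statement that a larger $\lambda$ gives the data likelihood more weight; the exact form used in the paper may differ. The prior is a narrow one-dimensional Gaussian standing in for a diffusion prior overfit to a few clean images, and the function name `scaled_posterior` is hypothetical.

```python
def scaled_posterior(y, prior_mean, prior_var, noise_var, lam):
    """Mean and variance of p(x) * p(y | x)**lam for Gaussian prior and noise.

    Assumption: lam rescales the log-likelihood, i.e. the posterior score is
    prior_score + lam * likelihood_score. Observation model: y = x + n with
    n ~ N(0, noise_var).
    """
    precision = 1.0 / prior_var + lam / noise_var
    mean = (prior_mean / prior_var + lam * y / noise_var) / precision
    return mean, 1.0 / precision


if __name__ == "__main__":
    # Overconfident prior (tiny variance) centred far from the observation.
    for lam in (1, 10, 20, 100):
        mean, var = scaled_posterior(y=3.0, prior_mean=0.0, prior_var=0.01,
                                     noise_var=1.0, lam=lam)
        print(f"lambda={lam:>3}: posterior mean={mean:.3f}, variance={var:.4f}")
```

With a small `lam` the posterior mean stays close to the prior mean regardless of the observation, which is the toy analogue of the mode collapse in Figure 2. Increasing `lam` pulls the posterior mean toward `y`.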