Improving Probabilistic Diffusion Models With Optimal Diagonal Covariance Matching

Zijing Ou, Mingtian Zhang, Andi Zhang, Tim Z. Xiao, Yingzhen Li, David Barber

TL;DR

Diffusion models rely on the covariance of the denoising distribution, and fixed or heuristic covariances limit sampling efficiency and density estimation. The paper introduces Optimal Covariance Matching (OCM), an unbiased objective that learns the diagonal of the optimal state-dependent covariance from the score function, enabling efficient covariance estimation with minimal overhead. The approach applies to DDPM, DDIM, and latent diffusion. It improves FID and NLL with fewer denoising steps in experiments on CIFAR-10, CelebA, LSUN, and ImageNet, and it scales to large latent-space models such as DiT. By linking covariance learning to the score-based framework while keeping computation tractable, OCM offers a practical route to faster, more accurate diffusion-based density estimation and generation in vision tasks.

Abstract

The probabilistic diffusion model has become highly effective across various domains. Typically, sampling from a diffusion model involves using a denoising distribution characterized by a Gaussian with a learned mean and either fixed or learned covariances. In this paper, we leverage the recently proposed covariance moment matching technique and introduce a novel method for learning the diagonal covariance. Unlike traditional data-driven diagonal covariance approximation approaches, our method involves directly regressing the optimal diagonal analytic covariance using a new, unbiased objective named Optimal Covariance Matching (OCM). This approach can significantly reduce the approximation error in covariance prediction. We demonstrate how our method can substantially enhance the sampling efficiency, recall rate and likelihood of commonly used diffusion models.
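
The figure list on this page refers to diagonal covariance estimation "with different Rademacher sample numbers". As a rough illustration of what such an estimate can look like, the sketch below uses the standard Hutchinson-style identity diag(H) = E[v ⊙ (Hv)] for Rademacher vectors v. It is not the paper's implementation: the function names (`matvec`, `rademacher`, `estimate_diagonal`) are hypothetical, and an explicit 3x3 symmetric matrix stands in for the Hessian-vector product that, in a diffusion model, would presumably be obtained from the score network.

```python
import random

def matvec(H, v):
    """Multiply a dense matrix H (list of rows) by a vector v."""
    return [sum(h_ij * v_j for h_ij, v_j in zip(row, v)) for row in H]

def rademacher(dim, rng):
    """Draw a vector with independent entries that are +1 or -1 with equal probability."""
    return [rng.choice((-1.0, 1.0)) for _ in range(dim)]

def estimate_diagonal(hvp, dim, num_probes, seed=0):
    """Hutchinson-style estimate of diag(H) using only matrix-vector products.

    hvp: callable returning H @ v for a vector v (assumed stand-in for a
         Hessian-vector product computed from a score model).
    """
    rng = random.Random(seed)
    total = [0.0] * dim
    for _ in range(num_probes):
        v = rademacher(dim, rng)
        hv = hvp(v)
        # v_i * (Hv)_i = H_ii + sum_{j != i} H_ij v_i v_j; the cross terms average to zero.
        for i in range(dim):
            total[i] += v[i] * hv[i]
    return [t / num_probes for t in total]

# Toy symmetric matrix playing the role of the Hessian (illustrative values only).
H = [[2.0, 0.3, -0.1],
     [0.3, 1.0, 0.2],
     [-0.1, 0.2, 0.5]]
true_diag = [H[i][i] for i in range(3)]
est = estimate_diagonal(lambda v: matvec(H, v), dim=3, num_probes=5000)
max_err = max(abs(e - t) for e, t in zip(est, true_diag))
```

Each probe costs one matrix-vector product, so the estimator never forms the full matrix. Its variance is driven by the off-diagonal entries, which is why the number of Rademacher samples matters for accuracy.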

Paper Structure (34 sections, 8 theorems, 41 equations, 14 figures, 13 tables, 2 algorithms)

Key Result

Theorem 1

Given a joint distribution $q(\tilde{x},x)=q(\tilde{x}|x)q(x)$ with $q(\tilde{x}|x)=\mathcal{N}(\alpha x,\sigma^2 I)$, then the covariance of the true posterior $q(x|\tilde{x})\propto q(x)q(\tilde{x}|x)$, which is defined as $\Sigma(\tilde{x})=\mathbb{E}_{q(x|\tilde{x})} [x^2] - \mathbb{E}_{q(x|\til

Figures (14)

  • Figure 1: Comparison of different covariance estimation methods. Panel (a) shows the training data and the ground truth density. Panels (b) and (c) present the MMD evaluation against the total number of sampling steps in the DDPM (b) and DDIM (c) settings.
  • Figure 2: FID vs. NLL for different methods with varying numbers of sampling steps on CIFAR10 (CS). Our method consistently achieves the best trade-off between FID and NLL.
  • Figure 3: Results of DiT training on ImageNet 256x256. We generate samples using 10 timesteps with varying CFG coefficients (see \ref{tab:dit-fid-recall} for exact numerical values).
  • Figure 4: Comparison of different covariance estimation methods based on estimation error and sample generation quality. Panels (a) and (b) show the mean squared error of the estimated diagonal covariance at various noise levels, (a) with access to the true score and (b) with a learned score of the data distribution. Panels (c) and (d) present the MMD evaluation against the total number of sampling steps in the DDPM (c) and DDIM (d) settings. The proposed OCM method achieves the lowest estimation error and consistently outperforms the baseline methods when fewer generation steps are used.
  • Figure 5: Diagonal covariance estimation visualisation with different Rademacher sample numbers.
  • ...and 9 more figures

Theorems & Definitions (14)

  • Theorem 1: Generalized Analytical Covariance Identity
  • Theorem 2: Validity of the OCM objective
  • Theorem 1: Generalized Analytical Covariance Identity
  • proof
  • Theorem 2: Validity of the OCM objective
  • proof
  • Lemma 1: First-order Tweedie's formula (Efron, 2011)
  • Lemma 2: Second-order Tweedie's formula
  • proof
  • Lemma 3: Convert the covariance of $q(\tilde{x} | x)$ to the Hessian of $q(\tilde{x})$
  • ...and 4 more