Solving Prior Distribution Mismatch in Diffusion Models via Optimal Transport

Zhanpeng Wang, Shenghao Li, Jiameng Che, Chen Wang, Shangling Jui, Na Lei, Zhongxuan Luo

TL;DR

The paper identifies a fundamental prior distribution mismatch in diffusion models: the forward terminal distribution $p_T$ often does not match the reverse initial distribution $q_T$. This leaves a non-zero terminal signal-to-noise ratio (SNR) and lets denoising errors accumulate, which degrades sampling. To remove the mismatch, the paper introduces an Optimal Transport (OT)-based prior error eliminator. An OT map $\nabla u_T^{\gets}$ pushes the steady-state distribution $p_\infty$ forward onto the forward terminal distribution $p_T$, and the reverse process is initialized from $q_T = (\nabla u_T^{\gets})_{\#}\, p_\infty$. The authors derive a Wasserstein-2 upper bound that ties the remaining error to both the score-matching objective $\mathcal{J}_{SM}$ and the approximation error of the OT map. They also establish asymptotic consistency between dynamic OT and the probability flow. Empirically, the method fully eliminates the prior error in discrete settings, and it improves generation quality and accelerates sampling on multiple image datasets. These results support both the theoretical guarantees and the practical utility. Overall, the work offers a rigorous, universal framework for improving diffusion models by correcting the prior distribution alignment via OT, with clear implications for faster and more faithful generative sampling.
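
To make the construction of $q_T$ from $p_\infty$ via $\nabla u_T^{\gets}$ concrete, here is a minimal NumPy sketch for a toy case. The toy case is an assumption of this example and is not the paper's setting. The data are Gaussian, $p_0 = \mathcal{N}(m_0, S_0)$, and they pass through a variance-preserving forward process with an assumed $\bar\alpha_T$. Under these assumptions $p_T$ is Gaussian, and the quadratic-cost OT map from $p_\infty = \mathcal{N}(\boldsymbol{0},\boldsymbol{I})$ has the closed form $x \mapsto \mu_T + \Sigma_T^{1/2}x$. The helper names `psd_sqrt` and `gaussian_ot_map` are hypothetical. For real image data the map has no closed form and must be computed or learned, and this sketch does not attempt that.

```python
import numpy as np

def psd_sqrt(S):
    """Symmetric square root of a positive semi-definite matrix."""
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def gaussian_ot_map(mu_T, Sigma_T):
    """Quadratic-cost OT map pushing N(0, I) onto N(mu_T, Sigma_T).

    It is the gradient of the convex potential
    u(x) = mu_T . x + 0.5 * x^T Sigma_T^{1/2} x, so by Brenier's theorem it is
    the optimal map between these two Gaussians.
    """
    A = psd_sqrt(Sigma_T)
    return lambda x: mu_T + x @ A  # rows of x are samples; A is symmetric

# Hypothetical toy setup: Gaussian data p_0 = N(m0, S0) under a VP forward process,
# so p_T = N(sqrt(a) * m0, a * S0 + (1 - a) * I) with a = alpha_bar_T.
m0 = np.array([3.0, -2.0])
S0 = np.array([[1.0, 0.5], [0.5, 2.0]])
alpha_bar_T = 0.05  # assumed value; a real noise schedule would supply it
mu_T = np.sqrt(alpha_bar_T) * m0
Sigma_T = alpha_bar_T * S0 + (1.0 - alpha_bar_T) * np.eye(2)

rng = np.random.default_rng(0)
z = rng.standard_normal((200_000, 2))    # samples from p_inf = N(0, I)
x_T = gaussian_ot_map(mu_T, Sigma_T)(z)  # samples from the pushed-forward prior q_T

print("empirical mean:", np.round(x_T.mean(axis=0), 3), "target:", np.round(mu_T, 3))
print("empirical cov:\n", np.round(np.cov(x_T, rowvar=False), 3))
print("target cov:\n", np.round(Sigma_T, 3))
```

The potential $u$ is convex, so Brenier's theorem guarantees that its gradient is the optimal map. The pushed-forward samples should therefore reproduce the mean and covariance of $p_T$ up to Monte Carlo error, and the prior error vanishes in this toy case.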

Abstract

Diffusion Models (DMs) have achieved remarkable progress in generative modeling. However, the mismatch between the forward terminal distribution and the reverse initial distribution introduces a prior error. This error causes sampling trajectories to deviate from the true distribution and severely limits model performance. It also triggers cascading problems, including a non-zero terminal Signal-to-Noise Ratio (SNR), accumulated denoising errors, degraded generation quality, and constrained sampling efficiency. To address it, this paper proposes a prior error elimination framework based on Optimal Transport (OT). Specifically, an OT map from the reverse initial distribution to the forward terminal distribution is constructed so that the two distributions match precisely. An upper bound on the prior error is then derived in terms of the Wasserstein distance, showing that the OT map can effectively eliminate the prior error. Additionally, deriving the asymptotic consistency between dynamic OT and the probability flow shows that the method is highly compatible with the intrinsic mechanism of the diffusion process. Theoretical analysis and experimental results show that the proposed method completely eliminates the prior error. Together they provide a universal and rigorous solution for improving the performance of DMs.
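
When both distributions are Gaussian, the prior error $\mathcal{W}_2(p_T, q_T)$ has a closed form, which makes the cost of the mismatch easy to see:
$\mathcal{W}_2^2 = \lVert\mu_1-\mu_2\rVert^2 + \operatorname{tr}\big(\Sigma_1+\Sigma_2-2(\Sigma_2^{1/2}\Sigma_1\Sigma_2^{1/2})^{1/2}\big)$.
The NumPy sketch below is illustrative only. The Gaussian data, the value $\bar\alpha_T = 0.05$, and the helper `w2_gaussian` are assumptions of this example. The code evaluates the prior error itself and does not reproduce the paper's upper bound.

```python
import numpy as np

def psd_sqrt(S):
    """Symmetric square root of a positive semi-definite matrix."""
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def w2_gaussian(mu1, S1, mu2, S2):
    """Closed-form 2-Wasserstein distance between N(mu1, S1) and N(mu2, S2)."""
    r2 = psd_sqrt(S2)
    cross = psd_sqrt(r2 @ S1 @ r2)
    sq = np.sum((mu1 - mu2) ** 2) + np.trace(S1 + S2 - 2.0 * cross)
    return float(np.sqrt(max(sq, 0.0)))  # clip tiny negative round-off

# Hypothetical toy case: Gaussian data N(m0, S0) and a VP forward process with an
# assumed alpha_bar_T = a, so p_T = N(sqrt(a) * m0, a * S0 + (1 - a) * I).
a = 0.05
m0 = np.array([3.0, -2.0])
S0 = np.array([[1.0, 0.5], [0.5, 2.0]])
mu_T = np.sqrt(a) * m0
Sigma_T = a * S0 + (1.0 - a) * np.eye(2)

err_standard = w2_gaussian(mu_T, Sigma_T, np.zeros(2), np.eye(2))  # q_T = N(0, I)
err_matched = w2_gaussian(mu_T, Sigma_T, mu_T, Sigma_T)            # q_T equals p_T
print(f"prior error with q_T = N(0, I):  {err_standard:.4f}")
print(f"prior error with matched q_T:    {err_matched:.1e}")
```

The standard Gaussian prior leaves a strictly positive prior error whenever the data mean or covariance leaks into $p_T$. A prior that matches $p_T$ exactly drives the error to zero, up to floating-point round-off.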

Paper Structure

This paper contains 16 sections, 7 theorems, 22 equations, 4 figures, 3 tables, 2 algorithms.

Key Result

Theorem 3.2

If $q_{T}=p_{\infty}$ and $\boldsymbol{S}_{\boldsymbol{\theta}}(\boldsymbol{x},t)\equiv\nabla\log p_{t}(\boldsymbol{x})$ on $\mathbb{R}^{n}\times[\varepsilon,T]$, let $q_\varepsilon$ denote the distribution generated by the DM. Then $\mathcal{W}_2(p_\varepsilon, q_\varepsilon)$ ad[…], where $I(t)=\exp\Big(\int_{0}^{t}\big(f(\tau)+g(\tau)^{2}L_{\boldsymbol{S}_{\boldsymbol{\theta}}}(\tau)\ldots$ […]

Figures (4)

  • Figure 1: In the forward process, DDPM (Ho et al., 2020) corrupts the true data distribution $p_{0}$ toward the Gaussian $\mathcal{N}(\boldsymbol{0},\boldsymbol{I})$, but in practice it terminates at $p_{T}$. In the reverse process, it adopts $q_{T}=\mathcal{N}(\boldsymbol{0},\boldsymbol{I})$ as the prior. This mismatch induces the prior error $\mathcal{W}_{2}(p_{T},q_{T})$, which increases the Wasserstein distance between the generated distribution $q_{0}$ and $p_{0}$. DDPM's only way to mitigate this effect is to increase the diffusion time $T$, which raises training cost, lowers sampling efficiency, and causes excessive error accumulation. In contrast, our method eliminates the prior error via an OT map and is more robust to variations in $T$, so it extends to accelerated sampling.
  • Figure 2: Generalization brought by $T$ and $\tau$. Red: original image.
  • Figure 3: Generated images on four datasets with OT accelerator.
  • Figure 4: MMR (defined by its equation in the paper) quantifies the degree of mode mixture. It relies on the observation that mixed and unmixed images behave differently under a classifier. Subfigure (b) shows how MMR ($\lambda=0.2$) varies with $T$ for different truncated models. A lower MMR indicates that generated images are more class-discriminative and well separated.
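
The mismatch described in Figure 1 can be checked numerically for a linear $\beta$ schedule. This sketch uses the DDPM defaults $\beta_1 = 10^{-4}$ and $\beta_T = 0.02$. For values of $T$ other than 1000 it keeps those endpoints, which is an assumption of this example. It also assumes 1-D Gaussian data $\mathcal{N}(m, s^2)$, for which $p_T = \mathcal{N}(\sqrt{\bar\alpha_T}\,m,\ \bar\alpha_T s^2 + 1 - \bar\alpha_T)$. The helper names `alpha_bar_T` and `prior_error_1d` are hypothetical.

```python
import numpy as np

def alpha_bar_T(T, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) over a linear beta schedule with T steps."""
    betas = np.linspace(beta_start, beta_end, T)
    return float(np.prod(1.0 - betas))

def prior_error_1d(a, m, s):
    """W2(p_T, N(0, 1)) for data N(m, s^2), where p_T = N(sqrt(a) m, a s^2 + 1 - a)."""
    mean_T = np.sqrt(a) * m
    std_T = np.sqrt(a * s**2 + 1.0 - a)
    # For 1-D Gaussians: W2^2 = (mean difference)^2 + (std difference)^2
    return float(np.hypot(mean_T, std_T - 1.0))

Ts = (100, 250, 500, 1000)
alphas = [alpha_bar_T(T) for T in Ts]
errors = [prior_error_1d(a, m=5.0, s=0.5) for a in alphas]  # assumed toy data N(5, 0.25)
for T, a, err in zip(Ts, alphas, errors):
    snr = a / (1.0 - a)  # terminal signal-to-noise ratio
    print(f"T={T:5d}  alpha_bar_T={a:.2e}  SNR(T)={snr:.2e}  W2(p_T, N(0,1))={err:.2e}")
```

A larger $T$ shrinks $\bar\alpha_T$, the terminal SNR $\bar\alpha_T/(1-\bar\alpha_T)$, and the prior error. None of them reaches zero for finite $T$, which is the trade-off the caption attributes to DDPM.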

Theorems & Definitions (13)

  • Remark 3.1
  • Theorem 3.2: Proof in Appendix E.4
  • Remark 3.3
  • Theorem 3.4: Proof in Appendix E.6
  • Lemma 3.5: Proof in Appendix E.7
  • Theorem 3.6: Proof in Appendix E.8
  • Corollary 3.7: Proof in Appendix E.9
  • Remark 3.8
  • Remark 3.9
  • Corollary 3.10: Proof in Appendix E.10
  • ...and 3 more