On Inference Stability for Diffusion Models

Viet Nguyen, Giang Vu, Tung Nguyen Thanh, Khoat Than, Toan Tran

TL;DR

The paper tackles inference stability in diffusion probabilistic models. It argues that because neighboring timesteps are correlated, per-step prediction errors accumulate along the sampling trajectory. It introduces a Sequence-Aware (SA) loss, which constrains a local sum of prediction errors over $K$ consecutive steps and is combined with the standard loss as $\mathcal{L}=\mathcal{L}_{simple}+\lambda\mathcal{L}_{sa}$. The authors prove that this SA loss provides a tighter upper bound on the total estimation gap than conventional losses. Empirically, SA-DPM consistently improves FID and Inception Score on CIFAR10, CelebA, and CelebA-HQ, especially when sampling with few timesteps, and gives further gains when paired with covariance-estimation methods (Analytic-DPM, NPR-DPM, SN-DPM). Code is available for reproduction. Overall, the work offers a theoretically grounded and empirically validated way to reduce the estimation gap at inference time, improving sample quality and stability, particularly in the few-step regime.
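A minimal sketch of how the combined objective $\mathcal{L}=\mathcal{L}_{simple}+\lambda\mathcal{L}_{sa}$ could be computed for one training example follows. It is not the authors' implementation. The names `sa_dpm_loss`, `predict_noise`, `linear_alpha_bar` and `zero_predictor` are hypothetical, and the linear $\beta$ schedule ($10^{-4}$ to $0.02$ over 1000 steps) is borrowed from standard DDPM practice. It also assumes that the $K$ consecutive steps $t, t-1, \dots, t-K+1$ share one noise draw and that their errors are summed without weights. The paper's $\mathcal{L}_{sa}$ may weight or sample these terms differently.

```python
import math
import random


def linear_alpha_bar(T=1000, beta_1=1e-4, beta_T=0.02):
    """alpha_bar[t] for t = 0..T under an assumed linear beta schedule."""
    alpha_bar, running = [1.0], 1.0  # alpha_bar[0] = 1: no noise at t = 0
    for t in range(1, T + 1):
        beta_t = beta_1 + (beta_T - beta_1) * (t - 1) / (T - 1)
        running *= 1.0 - beta_t
        alpha_bar.append(running)
    return alpha_bar


def sq_norm(v):
    """Squared Euclidean norm of a list of floats."""
    return sum(x * x for x in v)


def forward_sample(x0, eps, alpha_bar_t):
    """Draw x_t ~ q(x_t | x_0) using the given noise vector eps."""
    a, b = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [a * x + b * e for x, e in zip(x0, eps)]


def sa_dpm_loss(predict_noise, x0, alpha_bar, t, K, lam, rng):
    """Return (total, l_simple, l_sa) with total = l_simple + lam * l_sa.

    predict_noise(x_s, s) -> list of floats is a hypothetical noise predictor.
    Assumed, not taken from the paper: steps t, t-1, ..., t-K+1 share one
    noise draw eps, and their errors are summed without weights.
    """
    if t - K + 1 < 1:
        raise ValueError("need t - K + 1 >= 1")
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    summed_error = [0.0] * len(x0)
    l_simple = None
    for s in range(t, t - K, -1):  # s = t, t-1, ..., t-K+1
        x_s = forward_sample(x0, eps, alpha_bar[s])
        err = [e - p for e, p in zip(eps, predict_noise(x_s, s))]
        if l_simple is None:
            l_simple = sq_norm(err)  # usual epsilon-prediction loss at step t
        summed_error = [acc + e for acc, e in zip(summed_error, err)]
    l_sa = sq_norm(summed_error)  # penalises the local sum of errors
    return l_simple + lam * l_sa, l_simple, l_sa


def zero_predictor(x_s, s):
    """Toy predictor that always outputs zero noise."""
    return [0.0] * len(x_s)


alpha_bar = linear_alpha_bar()
x0 = [0.5, -1.0, 0.25]
total, l_simple, l_sa = sa_dpm_loss(
    zero_predictor, x0, alpha_bar, t=500, K=4, lam=0.1, rng=random.Random(0)
)
```

With the zero predictor, every per-step error equals the shared noise, so under these assumptions $\mathcal{L}_{sa}$ comes out at $K^2$ times $\mathcal{L}_{simple}$. This is the kind of same-direction accumulation the SA term is meant to penalise.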

Abstract

Denoising Probabilistic Models (DPMs) represent an emerging domain of generative models that excel in generating diverse and high-quality images. However, most current training methods for DPMs neglect the correlation between timesteps, which limits the model's ability to generate images effectively. Notably, we show theoretically that this issue can be caused by the cumulative estimation gap between the predicted and the actual trajectory. To minimize that gap, we propose a novel sequence-aware loss that reduces the estimation gap and thereby enhances sampling quality. Furthermore, we prove that our proposed loss function is a tighter upper bound of the estimation loss than the conventional loss in DPMs. Experimental results on several benchmark datasets, including CIFAR10, CelebA, and CelebA-HQ, consistently show a remarkable improvement of our proposed method in image generation quality, measured by FID and Inception Score, compared to several DPM baselines. Our code and pre-trained checkpoints are available at https://github.com/VinAIResearch/SA-DPM.
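Two standard inequalities illustrate why a penalty on the summed error can be tighter than per-step penalties. This is a generic illustration, not the paper's Theorem 2. Here $\mathbf{e}_i$ (notation introduced for this illustration) denotes the noise-prediction error at the $i$-th of $K$ consecutive steps. The triangle inequality and the Cauchy-Schwarz inequality give

$$\Big\|\sum_{i=1}^{K}\mathbf{e}_i\Big\|^2 \;\le\; \Big(\sum_{i=1}^{K}\|\mathbf{e}_i\|\Big)^2 \;\le\; K\sum_{i=1}^{K}\|\mathbf{e}_i\|^2 .$$

A per-step loss controls the right-hand side, which is never smaller than the left. The two coincide only when all $\mathbf{e}_i$ are equal. When neighboring errors partly cancel, the left-hand side, which is what a loss on the summed error targets, can be far smaller.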

Paper Structure

This paper contains 21 sections, 3 theorems, 35 equations, 10 figures, 4 tables, and 3 algorithms.

Key Result

Theorem 1

Let $\boldsymbol{f}_{\theta}(\boldsymbol{x}_{s}, s)$ be a noise predictor with parameter $\theta$. Its total gap from step 2 to $T$, for each $\boldsymbol{x}_{0}$, is … where $\tau_{i} = \frac{\sqrt{\bar{\alpha}_{i-1}}(1-\bar{\alpha}_{1})}{\sqrt{\alpha_{1}}(1-\bar{\alpha}_{i-1})}\gamma_{1, i} \frac{\sqrt{1-\bar{\alpha}_{i}}}{\sqrt{\bar{\alpha}_{i}}}$. Furthermore, the total loss of $\boldsymbol{f}_{\theta}$ …
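To build intuition for the weights in Theorem 1, $\tau_i$ can be evaluated numerically for a concrete noise schedule. This is a sketch under stated assumptions. The linear $\beta$ schedule ($10^{-4}$ to $0.02$ over $T=1000$ steps, as in standard DDPM) is chosen only for illustration. $\gamma_{1,i}$ is defined in the paper but not reproduced in this excerpt, so it is passed in by the caller as a plain number. The function names are hypothetical.

```python
import math


def linear_schedule(T=1000, beta_1=1e-4, beta_T=0.02):
    """1-indexed lists alpha[t], alpha_bar[t] (index 0 unused); assumed schedule."""
    alpha, alpha_bar, running = [None], [None], 1.0
    for t in range(1, T + 1):
        beta_t = beta_1 + (beta_T - beta_1) * (t - 1) / (T - 1)
        a_t = 1.0 - beta_t
        running *= a_t
        alpha.append(a_t)
        alpha_bar.append(running)
    return alpha, alpha_bar


def tau(i, alpha, alpha_bar, gamma_1i):
    """Evaluate tau_i from Theorem 1 for 2 <= i <= T.

    gamma_1i stands in for gamma_{1,i}; it is a caller-supplied number here
    (hypothetical input), not computed from the paper's definition.
    """
    if i < 2:
        raise ValueError("tau_i is defined for i >= 2")
    prefactor = (math.sqrt(alpha_bar[i - 1]) * (1.0 - alpha_bar[1])) / (
        math.sqrt(alpha[1]) * (1.0 - alpha_bar[i - 1])
    )
    noise_to_signal = math.sqrt(1.0 - alpha_bar[i]) / math.sqrt(alpha_bar[i])
    return prefactor * gamma_1i * noise_to_signal


alpha, alpha_bar = linear_schedule()
for i in (2, 10, 100, 1000):
    print(i, tau(i, alpha, alpha_bar, gamma_1i=1.0))
```

Because $\bar{\alpha}_1=\alpha_1$, the prefactor equals 1 at $i=2$. `tau(2, ...)` therefore reduces to $\gamma_{1,2}\sqrt{(1-\bar{\alpha}_2)/\bar{\alpha}_2}$, which gives a quick sanity check on the implementation.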

Figures (10)

  • Figure 1: 1-D example of a sampling trajectory. Under the assumption that the error at each timestep is similar: (a) the cumulative error by steps is large while (b) the cumulative error by steps is small. This behavior is due to the correlation between neighboring timesteps.
  • Figure 2: Qualitative results on CelebA-HQ 256 $\times$ 256.
  • Figure 3: Qualitative results on (a) CIFAR10 32 $\times$ 32 and (b) CelebA 64 $\times$ 64.
  • Figure 4: Total gap term $\bar{d}_{\theta, t}$ when sampling an image starting from $\boldsymbol{x}_{300}$ on the CIFAR10 dataset.
  • Figure 5: Total gap term $\bar{d}_{\theta, t}$ when sampling an image starting from $\boldsymbol{x}_{300}$ on the CIFAR10 dataset (a, b, c) and the CelebA $64 \times 64$ dataset (d). $\lambda = 0$ denotes the Vanilla DPM.
  • ...and 5 more figures
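The intuition in Figure 1 can be reproduced with a toy 1-D example. This is an illustrative sketch, not data from the paper's figure. The step count `T = 10` and error size `0.1` are arbitrary choices. "Same sign" versus "alternating sign" is an assumed stand-in for errors that do or do not accumulate across neighboring timesteps.

```python
from itertools import accumulate


def cumulative_error(step_errors):
    """Running sum of per-step errors along a 1-D sampling trajectory."""
    return list(accumulate(step_errors))


T, magnitude = 10, 0.1
# (a) errors of equal size and sign at every step: they add up.
same_sign = [magnitude] * T
# (b) errors of equal size but alternating sign: neighbours cancel.
alternating = [magnitude * (-1) ** k for k in range(T)]

for name, errs in [("same sign", same_sign), ("alternating", alternating)]:
    per_step = sum(e * e for e in errs)  # what a per-step squared loss sees
    final = cumulative_error(errs)[-1]   # where the trajectory ends up
    print(f"{name:12s} per-step sq. error = {per_step:.3f}, final offset = {final:+.3f}")
```

Both sequences have the same total per-step squared error (0.1). However, the first trajectory ends about 1.0 away from its target, and the second ends at 0. A loss that scores each step in isolation cannot tell the two apart.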

Theorems & Definitions (3)

  • Theorem 1: Estimation gap
  • Theorem 2
  • Lemma 3