Improved Sample Complexity Bounds for Diffusion Model Training

Shivam Gupta, Aditya Parulekar, Eric Price, Zhiyang Xun

TL;DR

This work studies the sample complexity of training diffusion models, specifically of learning neural-network-parameterized score estimates for DDPM sampling. The authors replace the usual $L^2$ accuracy requirement on the score with a quantile-based $D^{\delta}$ measure. This yields polylogarithmic dependence on the target Wasserstein accuracy $\gamma$ and exponential improvements in the dependence on the depth $D$ and the Wasserstein error, at the cost of a modest increase in the dependence on $\varepsilon$. They prove that a neural network class can learn the scores to sufficient quantile accuracy from $m = \widetilde{O}\left(\frac{d^2 P D}{\varepsilon^3} \log \Theta \log^3 \frac{1}{\gamma}\right)$ samples. They also prove that this accuracy suffices for DDPM reverse-SDE sampling to output a distribution within $\varepsilon$ in total variation (TV) of $q_{\gamma}$. The proof combines concentration bounds for training with a Girsanov-based analysis of sampling. The paper also gives information-theoretic hardness results for $L^2$ score learning, which motivate the quantile-based approach.
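To make the $\gamma$- and $D$-dependence of the stated bound concrete, here is a minimal sketch that evaluates only its leading term. It drops the constants and the logarithmic factors hidden by $\widetilde{O}$, so only ratios between calls are meaningful, not absolute sample counts. The function name `sample_bound` and all parameter values are hypothetical, chosen for illustration.

```python
import math


def sample_bound(d, P, D, eps, Theta, gamma):
    """Leading term of m = O~(d^2 P D / eps^3 * log(Theta) * log^3(1/gamma)).

    Constants and the log factors hidden by O~ are dropped, so only ratios
    between calls are meaningful.
    """
    return d**2 * P * D / eps**3 * math.log(Theta) * math.log(1.0 / gamma) ** 3


# Hypothetical parameters, chosen only for illustration.
base = dict(d=100, P=10**6, D=10, eps=0.1, Theta=10.0)

# Shrinking gamma from 1e-2 to 1e-4 doubles log(1/gamma), so the bound grows by 2^3 = 8.
ratio_gamma = sample_bound(gamma=1e-4, **base) / sample_bound(gamma=1e-2, **base)

# Doubling the depth D doubles the leading term, because it is linear in D.
ratio_depth = sample_bound(gamma=1e-2, **{**base, "D": 20}) / sample_bound(gamma=1e-2, **base)

print(round(ratio_gamma, 6), round(ratio_depth, 6))  # 8.0 2.0
```

Squaring the Wasserstein accuracy target (1e-2 to 1e-4) therefore costs only a constant factor of 8 in this leading term. A bound polynomial in $1/\gamma$ would instead grow by a factor polynomial in 100.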

Abstract

Diffusion models have become the most popular approach to deep generative modeling of images, largely due to their empirical performance and reliability. From a theoretical standpoint, a number of recent works have studied the iteration complexity of sampling, assuming access to an accurate diffusion model. In this work, we focus on understanding the sample complexity of training such a model: how many samples are needed to learn an accurate diffusion model using a sufficiently expressive neural network? Prior work showed bounds polynomial in the dimension, the desired total variation error, and the Wasserstein error. We show an exponential improvement in the dependence on Wasserstein error and depth, along with improved dependencies on other relevant parameters.

Paper Structure

This paper contains 32 sections, 34 theorems, 175 equations, 2 figures, 1 algorithm.

Key Result

Theorem 1.2

In the paper's main setting (set:1), suppose Assumptions A1 and A2 hold. For any $\varepsilon\in(0,1)$ and $\gamma\in(0,1)$, there exists a discretization schedule such that score functions trained on i.i.d. samples from $q_0$ yield, with probability $0.99$ over the draw of the training samples, a DDPM sampler whose output distribution is $\varepsilon$-close in TV to a distribution that is $\gamma m_2$-close to $q$ in 2-Wasserstein distance.
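The sampler in Theorem 1.2 runs the reverse SDE and stops early, so its target is a slightly noised distribution rather than $q$ itself. The following is a minimal sketch of that idea under strong simplifying assumptions. The data are 1-D Gaussian, $q_0 = \mathcal{N}(\mu, \sigma^2)$, so the exact score is available in closed form and stands in for a trained network. The forward process is assumed to be the Ornstein-Uhlenbeck process $dx = -x\,dt + \sqrt{2}\,dB$. Plain Euler-Maruyama with a uniform time grid replaces the paper's discretization schedule, which is not reproduced here. The function names and the parameters `T`, `t_min` and `steps` are hypothetical.

```python
import numpy as np


def gaussian_score(x, t, mu, sigma):
    """Exact score of q_t for q_0 = N(mu, sigma^2) under the OU forward process
    dx = -x dt + sqrt(2) dB, for which q_t = N(mu e^{-t}, sigma^2 e^{-2t} + 1 - e^{-2t})."""
    mean = mu * np.exp(-t)
    var = sigma**2 * np.exp(-2 * t) + 1.0 - np.exp(-2 * t)
    return -(x - mean) / var


def reverse_sde_sample(n, mu, sigma, T=5.0, t_min=1e-3, steps=500, seed=0):
    """Euler-Maruyama discretization of the reverse SDE
    dy = [y + 2 s_t(y)] dtau + sqrt(2) dB, run from t = T down to t = t_min.

    Stopping at t_min > 0 (early stopping) means the output approximates
    q_{t_min}, a slightly noised version of q_0.
    """
    rng = np.random.default_rng(seed)
    ts = np.linspace(T, t_min, steps + 1)  # uniform schedule (an assumption)
    y = rng.standard_normal(n)             # N(0, 1) is close to q_T for large T
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        h = t_cur - t_next
        drift = y + 2.0 * gaussian_score(y, t_cur, mu, sigma)
        y = y + h * drift + np.sqrt(2.0 * h) * rng.standard_normal(n)
    return y


samples = reverse_sde_sample(n=20_000, mu=2.0, sigma=0.5)
print(samples.mean(), samples.std())  # close to 2.0 and 0.5
```

Here `t_min` plays a role loosely analogous to $\gamma$. The output approximates $q_{t_{\min}}$, whose mean $\mu e^{-t_{\min}}$ and variance $\sigma^2 e^{-2t_{\min}} + 1 - e^{-2t_{\min}}$ approach those of $q_0$ as $t_{\min} \to 0$.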

Figures (2)

  • Figure 1: Given $o\left( \frac{1}{\eta} \right)$ samples from either $p_1 = (1 - \eta)\mathcal{N}(0, 1) + \eta \mathcal{N}(-R, 1)$ or $p_2 = (1 - \eta) \mathcal{N}(0, 1) + \eta \mathcal{N}(R, 1)$, with high probability we see only samples from the main Gaussian and cannot distinguish between the two. However, if we pick the wrong score function, the incurred $L^2$ error is large: about $\eta R^2$. On the right, we take $\eta = 0.001$, $R = 10000$, $\delta = 0.01$, and plot the probability that the ERM has error larger than $0$ in the $L^2$ sense and in our $D_p^\delta$ sense.
  • Figure 2: For $m$ samples from $\mathcal{N}(0, 1)$, consider the score $\widehat{s}$ of the mixture $\eta \mathcal{N}(0, 1) + (1 - \eta) \mathcal{N}(R, 1)$ shown above, with $\eta$ chosen so that $\widehat{s}(10 \sqrt{\log m}) = 0$. For this $\widehat{s}$, the score-matching objective is close to $0$, while the squared $L^2$ error is $\Omega\left(\frac{R^2}{m} \right)$.
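The contrast in Figure 1 can be checked numerically. The following is a minimal sketch assuming that $D_p^\delta(\widehat{s}, s)$ is the $(1-\delta)$-quantile of $\|\widehat{s}(x) - s(x)\|$ under $x \sim p$. That is the smallest $\epsilon$ with $\Pr_{x\sim p}[\|\widehat{s}(x) - s(x)\| > \epsilon] \le \delta$. It computes the squared $L^2$ error and this quantile when the wrong score ($p_2$'s) is used on data from $p_1$. The values $\eta = 0.01$, $R = 20$, $\delta = 0.05$ are hypothetical and much smaller than the figure's, so that a simple grid integral suffices. The helper names are also hypothetical.

```python
import numpy as np


def mixture_score(x, weights, means):
    """Score d/dx log p(x) of a 1-D mixture of unit-variance Gaussians,
    using a max-shifted softmax posterior for numerical stability."""
    x = np.asarray(x, dtype=float)[:, None]
    means = np.asarray(means, dtype=float)
    logits = np.log(weights)[None, :] - 0.5 * (x - means[None, :]) ** 2
    logits -= logits.max(axis=1, keepdims=True)
    post = np.exp(logits)
    post /= post.sum(axis=1, keepdims=True)  # posterior over components given x
    return -(x[:, 0] - post @ means)         # score = -(x - E[component mean | x])


eta, R, delta = 0.01, 20.0, 0.05  # illustrative values, smaller than Figure 1's
w = np.array([1 - eta, eta])


def s1(x):  # true score of p_1 (minor component at -R)
    return mixture_score(x, w, [0.0, -R])


def s2(x):  # wrong score: that of p_2 (minor component at +R)
    return mixture_score(x, w, [0.0, R])


# Squared L2 error E_{x ~ p_1}[(s2(x) - s1(x))^2] via a Riemann sum on a fine grid.
grid = np.linspace(-40.0, 40.0, 80_001)
dx = grid[1] - grid[0]
dens = ((1 - eta) * np.exp(-grid**2 / 2) + eta * np.exp(-(grid + R) ** 2 / 2)) / np.sqrt(2 * np.pi)
l2_sq = np.sum(dens * (s2(grid) - s1(grid)) ** 2) * dx

# D^delta: empirical (1 - delta)-quantile of |s2 - s1| over samples from p_1.
rng = np.random.default_rng(0)
n = 200_000
x = rng.standard_normal(n) + np.where(rng.random(n) < eta, -R, 0.0)
d_delta = np.quantile(np.abs(s2(x) - s1(x)), 1 - delta)

print(l2_sq, d_delta)  # roughly eta * R^2 = 4.0, and essentially 0
```

The wrong score differs from the true one by about $R$ only on the minor component, which has mass $\eta$. Its squared $L^2$ error is therefore about $\eta R^2$. When $\delta > \eta$, that set falls outside the quantile, and $D_p^\delta$ does not penalize it.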

Theorems & Definitions (53)

  • Theorem 1.2
  • Lemma 1.2
  • Lemma 1.2: Main Lemma
  • Lemma 4.0
  • Lemma 4.0
  • Lemma A.1: Main Lemma, Quantitative Version
  • Lemma A.2: Score Estimation for Finite Function Class
  • proof
  • Lemma A.3
  • proof
  • ...and 43 more