
A theory of learning data statistics in diffusion models, from easy to hard

Lorenzo Bardone, Claudia Merger, Sebastian Goldt

Abstract

While diffusion models have emerged as a powerful class of generative models, their learning dynamics remain poorly understood. We address this issue first by empirically showing that standard diffusion models trained on natural images exhibit a distributional simplicity bias, learning simple, pair-wise input statistics before specializing to higher-order correlations. We reproduce this behaviour in simple denoisers trained on a minimal data model, the mixed cumulant model, where we precisely control both pair-wise and higher-order correlations of the inputs. We identify a scalar invariant of the model that governs the sample complexity of learning pair-wise and higher-order correlations that we call the diffusion information exponent, in analogy to related invariants in different learning paradigms. Using this invariant, we prove that the denoiser learns simple, pair-wise statistics of the inputs at linear sample complexity, while more complex higher-order statistics, such as the fourth cumulant, require at least cubic sample complexity. We also prove that the sample complexity of learning the fourth cumulant is linear if pair-wise and higher-order statistics share a correlated latent structure. Our work describes a key mechanism for how diffusion models can learn distributions of increasing complexity.
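
The abstract contrasts pair-wise statistics with higher-order ones such as the fourth cumulant. The sketch below is a toy illustration of our own, not the paper's mixed cumulant model. It compares a standard Gaussian with a Rademacher ($\pm 1$) variable. Both have zero mean and unit variance, so their pair-wise statistics coincide. The fourth cumulant of a centred variable, $\kappa_4 = \mathbb{E}[x^4] - 3\,\mathbb{E}[x^2]^2$, still differs: it is $0$ for the Gaussian and $-2$ for the Rademacher variable. The helper name `fourth_cumulant` is hypothetical.

```python
import numpy as np

def fourth_cumulant(x: np.ndarray) -> float:
    """Empirical fourth cumulant of a 1D sample: E[y^4] - 3 E[y^2]^2 with y = x - mean(x)."""
    y = x - x.mean()
    m2 = np.mean(y**2)
    m4 = np.mean(y**4)
    return float(m4 - 3.0 * m2**2)

rng = np.random.default_rng(0)
n = 200_000
gauss = rng.standard_normal(n)                   # Gaussian: kappa_4 = 0
rademacher = rng.choice([-1.0, 1.0], size=n)     # +-1 with equal probability: kappa_4 = -2

# Same mean and variance (pair-wise statistics), different fourth cumulant.
for name, x in [("gaussian", gauss), ("rademacher", rademacher)]:
    print(f"{name:>10}: var = {x.var():.3f}, kappa_4 = {fourth_cumulant(x):+.3f}")
```

Both variances should come out close to 1, and the estimated $\kappa_4$ values should be close to 0 and $-2$. A model that captures only pair-wise statistics cannot tell these two distributions apart.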

Paper Structure (30 sections, 6 theorems, 72 equations, 7 figures)

Key Result

Proposition 1

Assume that $L_t(x\cdot v)$ is the likelihood ratio of a sub-Gaussian random variable, and that $\sigma$ is an activation function such that $F_\sigma$ satisfies Assumptions A and B. Denote by $k^*$ the information exponent of the loss $\mathcal{L}$ and let $\hat{n}(d,k^*)$ be a sample complexity …; then the application of $\hat{n}(d,k^*)$ steps of projected gradient descent with step size $\eta_d$ …
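
Proposition 1 is stated in terms of the information exponent $k^*$. A common convention, assumed here (Definition 3 of the paper may differ in detail), expands a function in probabilists' Hermite polynomials: $f(z)=\sum_k c_k\,\mathrm{He}_k(z)$ with $c_k = \mathbb{E}_{z\sim\mathcal{N}(0,1)}[f(z)\,\mathrm{He}_k(z)]/k!$. Under this convention, $k^*$ is the smallest $k \ge 1$ with $c_k \neq 0$. The sketch computes the coefficients with NumPy's Gauss-Hermite quadrature. The function names, the cutoff `k_max` and the tolerance are illustrative choices.

```python
import math
from typing import Optional

import numpy as np
from numpy.polynomial import hermite_e as H

def hermite_coefficients(f, k_max: int = 8, n_nodes: int = 80) -> np.ndarray:
    """c_k = E[f(z) He_k(z)] / k! for z ~ N(0, 1) and k = 0..k_max (probabilists' Hermite)."""
    nodes, weights = H.hermegauss(n_nodes)     # quadrature for the weight exp(-z^2 / 2)
    weights = weights / np.sqrt(2.0 * np.pi)   # normalise to the standard Gaussian measure
    fz = f(nodes)
    coeffs = np.empty(k_max + 1)
    for k in range(k_max + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0                         # one-hot series selects He_k in hermeval
        he_k = H.hermeval(nodes, basis)
        coeffs[k] = np.sum(weights * fz * he_k) / math.factorial(k)
    return coeffs

def information_exponent(f, k_max: int = 8, tol: float = 1e-8) -> Optional[int]:
    """Smallest k >= 1 with |c_k| > tol, or None if no such k up to k_max."""
    c = hermite_coefficients(f, k_max)
    for k in range(1, k_max + 1):
        if abs(c[k]) > tol:
            return k
    return None

print(information_exponent(np.tanh))                         # -> 1 (odd function)
print(information_exponent(lambda z: z**2))                  # -> 2 (z^2 = He_2 + 1)
print(information_exponent(lambda z: z**4 - 6 * z**2 + 3))  # -> 4 (this is He_4)
```

Note that $z^4$ itself has $k^*=2$, not 4. Its expansion $z^4 = \mathrm{He}_4 + 6\,\mathrm{He}_2 + 3$ contains a second-order term.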

Figures (7)

  • Figure 1: Sequential learning in diffusion models. a) Test loss of the diffusion model and its loss on CIFAR-10 clones during training. Vertical dotted lines mark the training stages at which the images shown in panels d)-g) were generated. All curves are averages over 3 initializations of the network and $5\cdot10^3$ test samples; shaded areas show the standard deviation over random initializations. b) Same as a), but for denoising samples at a fixed noise level, $x=e^{-t}x_0+\sqrt{1-e^{-2t}}\,z$, where $x_0$ is a data point and $z \sim \mathcal{N}(0,\text{Id})$. c) Test loss of a neural network trained on the mixed cumulant model with dimension $d=10^2$ at a fixed noise level, evaluated on clones of the data set. All curves in panel c) are averages over 5 initializations of the network and $10^4$ test samples. d)-g) Samples generated from U-nets (Ronneberger et al., 2015) trained on CIFAR-10 at various training stages.
  • Figure 2: Examples of the contraction term $\Lambda$ for different choices of the activation $\sigma$. $\sigma^*$ denotes the activation matched to the functional form of the spiked-model score, shown for different values of the diffusion time $t$.
  • Figure 3: Normalized overlap of the first-layer weights of neural networks of varying depth, trained with Adam on inputs drawn from the mixed cumulant model at $d=100$. All curves are averages over 5 random initializations of the networks.
  • Figure A.1: Samples from the different "clones" and from the test data set. a) Images drawn from the mean clone, a Gaussian distribution whose mean matches that of the CIFAR-10 dataset and whose covariance is the identity. b) Images from a Gaussian whose covariance matrix is additionally matched to the CIFAR-10 dataset. c) Nine images from the CIFAR-10 dataset.
  • Figure A.2: Samples from the different "clones" and from the test data set. a) Images drawn from the mean clone, a Gaussian distribution whose mean matches that of the CelebA dataset and whose covariance is the identity. b) Images from a Gaussian whose covariance matrix is additionally matched to the CelebA dataset. c) Nine images from the CelebA dataset.
  • ...and 2 more figures
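
Figure 1 uses the forward noising map $x=e^{-t}x_0+\sqrt{1-e^{-2t}}\,z$. Figures A.1 and A.2 describe Gaussian "clones" that match either the mean of a dataset (with identity covariance) or both its mean and covariance. Below is a minimal NumPy sketch of both. It assumes the data come as a flattened array of shape (n_samples, d), e.g. $d = 32\cdot 32\cdot 3 = 3072$ for CIFAR-10. The names `gaussian_clone` and `forward_noise` are hypothetical, not taken from the paper's code.

```python
import numpy as np

def gaussian_clone(data: np.ndarray, n_samples: int, match_cov: bool,
                   rng: np.random.Generator) -> np.ndarray:
    """Sample a Gaussian clone of `data` (shape (n, d)).

    match_cov=False: matched mean, identity covariance (the "mean clone").
    match_cov=True:  matched mean and covariance.
    """
    mean = data.mean(axis=0)
    d = data.shape[1]
    cov = np.cov(data, rowvar=False) if match_cov else np.eye(d)
    return rng.multivariate_normal(mean, cov, size=n_samples)

def forward_noise(x0: np.ndarray, t: float, rng: np.random.Generator):
    """Return (x, z) with x = e^{-t} x0 + sqrt(1 - e^{-2t}) z and z ~ N(0, Id)."""
    z = rng.standard_normal(x0.shape)
    x = np.exp(-t) * x0 + np.sqrt(1.0 - np.exp(-2.0 * t)) * z
    return x, z

rng = np.random.default_rng(0)
toy_data = rng.standard_normal((1000, 4)) + 3.0   # stand-in for flattened images
cov_clone = gaussian_clone(toy_data, 500, match_cov=True, rng=rng)
x_t, z = forward_noise(cov_clone, t=0.5, rng=rng)
print(cov_clone.shape, x_t.shape)                 # (500, 4) (500, 4)
```

The noise schedule is variance preserving. For data with identity covariance, $e^{-2t} + (1-e^{-2t}) = 1$, so the noisy samples keep unit variance at every $t$.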

Theorems & Definitions (15)

  • Proposition 1: Positive result
  • Proposition 2: Negative result
  • proof
  • Proposition 3
  • Proposition 4
  • Definition 1
  • Definition 2: Hermite expansion
  • Definition 3: Information exponent
  • Lemma 1: Stein lemma
  • proof
  • ...and 5 more
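
Lemma 1 is listed as Stein's lemma. In its standard scalar form it states $\mathbb{E}[z f(z)] = \mathbb{E}[f'(z)]$ for $z\sim\mathcal{N}(0,1)$ and differentiable $f$ with $\mathbb{E}|f'(z)|<\infty$. The paper's version may be multivariate or otherwise adapted. The sketch below only checks the scalar identity numerically, using Gauss-Hermite quadrature on two test functions of our choosing. The helper `gaussian_expectation` is hypothetical.

```python
import numpy as np
from numpy.polynomial import hermite_e as H

def gaussian_expectation(g, n_nodes: int = 60) -> float:
    """E[g(z)] for z ~ N(0, 1) via probabilists' Gauss-Hermite quadrature."""
    nodes, weights = H.hermegauss(n_nodes)        # weights integrate against exp(-z^2 / 2)
    return float(np.sum(weights * g(nodes)) / np.sqrt(2.0 * np.pi))

# Pairs (f, f') used to check E[z f(z)] = E[f'(z)].
test_functions = {
    "sin": (np.sin, np.cos),
    "cubic": (lambda z: z**3, lambda z: 3 * z**2),
}

for name, (f, df) in test_functions.items():
    lhs = gaussian_expectation(lambda z: z * f(z))
    rhs = gaussian_expectation(df)
    print(f"{name:>5}: E[z f(z)] = {lhs:.10f}, E[f'(z)] = {rhs:.10f}")
```

For $f = \sin$, both sides equal $\mathbb{E}[\cos z] = e^{-1/2}$. For $f(z) = z^3$, both equal $\mathbb{E}[z^4] = 3$.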