
A solvable model of learning generative diffusion: theory and insights

Hugo Cui, Cengiz Pehlevan, Yue M. Lu

TL;DR

The paper develops a solvable model of learning generative diffusion. A two-layer denoising autoencoder (DAE) is trained with online SGD on a high-dimensional target density that has an underlying low-dimensional manifold structure. The paper derives a tight asymptotic description at two levels. First, deterministic ODEs govern a small set of summary statistics of the weights and capture the learning dynamics. Second, a low-dimensional transport SDE describes how generated samples evolve, and a corollary gives the law of their projection onto a fixed subspace. Together these yield sharp, interpretable characterizations of low-dimensional projections of the generated density and of how that density evolves during training. The results are illustrated on Gaussian mixtures and on a Gaussian model with the covariance of MNIST sevens. The analysis reveals architectural biases of the DAE that can cause mode collapse. When synthetic data are reused for training, these biases can also cause model collapse. This underscores the role of network design in diffusion-based generative modeling and its implications for reusing generated data.
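The transport of generated samples described above can be made concrete with a minimal sketch. Several choices here are assumptions rather than the paper's method. The sketch uses the interpolant $x_t=\alpha_t x_0+\beta_t x_1$ with $\alpha_t=1-t,\beta_t=t$, Gaussian noise $x_0$ and target sample $x_1$. It reads $\epsilon_t=0$ from the figure captions as removing the noise term, which turns the generative SDE into an ODE. In place of the learned DAE it uses the exact posterior-mean denoiser of a Gaussian target $\mathcal{N}(\mu, I_d)$. With a denoiser $f\approx\mathbb{E}[x_1\mid x_t]$, the velocity is $(f(x_t,t)-x_t)/(1-t)$. The function names are hypothetical. This is a generic high-dimensional sampler, not the paper's reduced transport SDE.

```python
import numpy as np

def exact_denoiser(x_t, t, mu):
    """Posterior mean E[x_1 | x_t] for x_1 ~ N(mu, I), x_0 ~ N(0, I),
    under the assumed interpolant x_t = (1 - t) x_0 + t x_1."""
    var_t = (1.0 - t) ** 2 + t ** 2          # per-coordinate variance of x_t
    return mu + (t / var_t) * (x_t - t * mu)

def sample_flow(mu, n_samples, t_max=0.98, n_steps=500, seed=0):
    """Euler integration of dX/dt = (f(X, t) - X) / (1 - t) from t = 0 to t_max.

    Stopping before t = 1 mirrors running the generative dynamics "up to t = 0.98".
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_samples, mu.shape[0]))   # X_0 ~ N(0, I_d)
    ts = np.linspace(0.0, t_max, n_steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        velocity = (exact_denoiser(x, t, mu) - x) / (1.0 - t)
        x = x + (t_next - t) * velocity
    return x
```

With the exact denoiser, the marginal at time $t$ is $\mathcal{N}(t\mu, ((1-t)^2+t^2) I_d)$. Comparing the sample mean and variance against this marginal is therefore a quick sanity check of the sampler. Replacing `exact_denoiser` with a trained network is where architectural biases would enter.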

Abstract

In this manuscript, we consider the problem of learning a flow- or diffusion-based generative model parametrized by a two-layer auto-encoder, trained with online stochastic gradient descent, on a high-dimensional target density with an underlying low-dimensional manifold structure. We derive a tight asymptotic characterization of low-dimensional projections of the distribution of samples generated by the learned model, ascertaining in particular its dependence on the number of training samples. Building on this analysis, we discuss how mode collapse can arise and lead to model collapse when the generative model is re-trained on generated synthetic data.
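The sketch below shows what training a two-layer auto-encoder with online SGD can look like. It rests on several assumptions, and every name in it is hypothetical. The first is the DAE form $f(x)=b\,x+W^\top\tanh(Wx/\sqrt{d})/\sqrt{d}$ with $r$ hidden units and a scalar skip connection $b$. The second is a denoising loss that regresses the clean sample $x_1$ from $x_t=(1-t)x_0+t x_1$ at a single time $t=1/2$, which is our reading of $\mathcal{G}=\{1/2\}$ in the figure captions. The third is weight decay $\lambda$. It is not the paper's exact normalization or learning-rate scaling.

```python
import numpy as np

def dae_forward(W, b, x):
    """Two-layer DAE with skip connection (assumed form):
    f(x) = b * x + W^T tanh(W x / sqrt(d)) / sqrt(d), with W of shape (r, d)."""
    d = x.shape[0]
    h = W @ x / np.sqrt(d)                       # r pre-activations
    return b * x + W.T @ np.tanh(h) / np.sqrt(d), h

def loss_and_grads(W, b, x_t, x1, lam):
    """Loss ||f(x_t) - x1||^2 / (2d) + lam ||W||^2 / (2d) and its gradients in W and b."""
    d = x_t.shape[0]
    f, h = dae_forward(W, b, x_t)
    err = f - x1
    loss = err @ err / (2 * d) + lam * np.sum(W ** 2) / (2 * d)
    a = np.tanh(h)
    # Backpropagate through the readout W^T a and through the pre-activations h.
    g_h = (1.0 - a ** 2) * (W @ err) / (d * np.sqrt(d))
    grad_W = np.outer(a, err) / (d * np.sqrt(d)) + np.outer(g_h, x_t) / np.sqrt(d) + lam * W / d
    grad_b = err @ x_t / d
    return loss, grad_W, grad_b

def online_sgd_step(W, b, x1, rng, t=0.5, eta=0.2, lam=0.0):
    """One online SGD step on a fresh target sample x1 and fresh Gaussian noise x0."""
    x0 = rng.standard_normal(x1.shape[0])
    x_t = (1.0 - t) * x0 + t * x1                # assumed alpha_t = 1 - t, beta_t = t
    _, grad_W, grad_b = loss_and_grads(W, b, x_t, x1, lam)
    # Plain step size eta; any dimension-dependent scaling used in the paper is not reproduced.
    return W - eta * grad_W, b - eta * grad_b
```

Calling `online_sgd_step` with a new target sample each time is what makes the SGD "online": no training sample is reused. Under this reading, training time $\tau$ corresponds to the number of such calls, possibly rescaled by $d$.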

Paper Structure

This paper contains 59 sections, 1 theorem, 69 equations, and 14 figures.

Key Result

Corollary 2.3

(Projected generated density) The law of the projection $E^\top X_t$ of a sample $X_t$ onto the space of interest $\mathcal{E}$ is given by …, where the law of $Z_t$ is characterized in Result res:transport by the SDE eq:transport_Z, and the summary statistics $\mathcal{Q}_\tau, G_\tau$ are characterized in Result res:training. Here $\mathcal{Q}_\tau^+$ denotes the Moore-Penrose pseudo-inverse of $\mathcal{Q}_\tau$.
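The corollary's explicit formula is not reproduced here. The empirical quantities it refers to can still be computed directly. The sketch below projects samples onto a subspace $\mathcal{E}$ with orthonormal basis $E$ and estimates the resulting 2D density by a histogram, which is the kind of "numerical experiments" colormap shown in the figures. It also forms a pseudo-inverse with `np.linalg.pinv`. Treating $\mathcal{Q}_\tau$ as the weight overlap $WW^\top/d$ is a hypothetical normalization, and $G_\tau$ is not modeled. The random `W` is only a stand-in for trained weights.

```python
import numpy as np

def overlap_matrix(W):
    """Hypothetical summary statistic: r x r weight overlap Q = W W^T / d."""
    return W @ W.T / W.shape[1]

def projected_histogram(X, E, bins=50):
    """Empirical density of E^T x over the rows of X (n x d).

    E is a d x 2 matrix whose orthonormal columns span the subspace of interest.
    Returns (density, x_edges, y_edges) as produced by np.histogram2d.
    """
    P = X @ E                                    # row i holds E^T x_i
    return np.histogram2d(P[:, 0], P[:, 1], bins=bins, density=True)

rng = np.random.default_rng(0)
d, r = 1000, 2
W = rng.standard_normal((r, d))                  # stand-in for trained weights
Q = overlap_matrix(W)
Q_pinv = np.linalg.pinv(Q)                       # Moore-Penrose pseudo-inverse, defined even if Q is singular
```

The pseudo-inverse rather than the inverse matters when the overlap is rank-deficient. For example, two hidden units may align with the same direction during training, in which case $\mathcal{Q}_\tau$ has no ordinary inverse.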

Figures (14)

  • Figure 1: Evolution of the summary statistics $M_\tau, \mathcal{Q}_\tau$ and of the skip connection strength $b_\tau$ as a function of the training time $\tau$, for $\sigma=\tanh, r=2, p_t=0,\alpha_t=1-t,\beta_t=t,\mathcal{G}=\{1/2\}$. The target density is a trimodal Gaussian mixture $\rho=\tfrac{1}{3}\mathcal{N}(\mu_1,I_d)+\tfrac{1}{3}\mathcal{N}(\mu_2,I_d)+\tfrac{1}{3}\mathcal{N}(\mu_3,I_d)$. Solid lines: numerical experiments in dimension $d=1000$. Dashed lines: theoretical characterization (eq:ODEs) of Result res:training.
  • Figure 2: (Left) Evolution of the projected density $\Pi_{\mathcal{E}}\hat{\rho}_\tau$ generated by a DAE (eq:AE) with $r=1$ hidden unit and $\sigma=\tanh$ activation, trained on a bimodal Gaussian mixture, with $\eta=0.2,\lambda=1.5, \epsilon_t=0, p_t=0,\alpha_t=1-t, \beta_t=t, \mathcal{G}=\{0.5\}$. The generative SDE (eq:sampling_SDE_AE) was run up to $t=0.9$, and the subspace $\mathcal{E}$ is a plane containing the centroid of the target density. Different panels correspond to different training times $\tau$. Blue contours: contour levels of the theoretical prediction of Corollary 2.3 for the density $\Pi_{\mathcal{E}}\hat{\rho}_\tau$. Colormap: numerical experiments in large but finite dimension $d=1000$. Green contours: contour levels of the target density $\rho$. (Right) Hellinger distance between the target and generated densities, projected onto the space spanned by the centroid, as a function of the training time $\tau$.
  • Figure 3: (Left) Evolution of the density $\Pi_{\mathcal{E}}\hat{\rho}_\tau$ generated by a DAE (eq:AE) with $r=2$ hidden units and $\sigma=\tanh$ activation, trained on a Gaussian density with the MNIST sevens covariance, with $\eta=0.2,\lambda=0.784, \epsilon_t=p_t=0,\alpha_t=1-t, \beta_t=t, \mathcal{G}=\{1/2\}$. The generative SDE (eq:sampling_SDE_AE) was run up to $t=0.98$, and the subspace $\mathcal{E}$ is spanned by principal components of the target density. Different panels correspond to different training times $\tau$. Blue contours: contour levels of the theoretical prediction of Corollary 2.3 for the density $\hat{\rho}_\tau$. Colormap: numerical experiments. Green contours: contour levels of the target density $\rho$. (Right) Samples from the generated density $\hat{\rho}_\tau$, from a common initialization $X_0$ of the generative SDE (eq:sampling_SDE_AE), as a function of the training time $\tau$.
  • Figure 4: (Left) Target density $\rho$, a Gaussian density with the covariance of the distribution of MNIST sevens. (Middle) Generated density $\Pi_{\mathcal{E}}\hat{\rho}^{(1)}_\tau$. (Right) Second-generation density $\Pi_{\mathcal{E}}\hat{\rho}^{(2)}_\tau$, obtained by training the generative model (eq:AE) on the synthetic distribution $\hat{\rho}^{(1)}_\tau$. Blue contours: contour levels of the theoretical prediction of Corollary 2.3. Colormap: numerical experiments in large but finite dimension $d=1000$. Green contours: contour levels of the target density $\rho$. The same model specifications $\tau=2.8, r=2, \sigma=\tanh, \eta=0.2,\lambda=0.784, \epsilon_t=p_t=0,\alpha_t=1-t, \beta_t=t, \mathcal{G}=\{1/2\}$ were used at each generation, and the generative SDE (eq:sampling_SDE_AE) was run up to $t=0.98$ each time. The subspace $\mathcal{E}$ is spanned by principal components of the target density.
  • Figure 5: Evolution of the summary statistics (eq:summ_stats) $M$ (left) and $Q$ (middle), and of the skip connection strength $b$ (right), characterizing the dynamics of the AE parameters (eq:AE) under the SGD dynamics (eq:SGD). Parameters $\sigma=\tanh, r=2, \lambda=0,\eta=0.2,\mathcal{G}=\{1/2\}$ were used, and the target density $\rho$ was taken to be a Gaussian mixture with three isotropic clusters (see also Fig. fig:Compo_evolution in the main text). The weight vectors were initialized along the centroids of the target density with norm $0.1$, and the initial skip connection strength is $b_0=0$. Dashed lines: theoretical characterization of Result res:training. Solid lines: numerical experiments in $d=1000$, for a single run.
  • ...and 9 more figures
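Figure 2 (right) tracks the Hellinger distance between projected target and generated densities. The sketch below computes that distance for two 1D densities tabulated on a uniform grid. It assumes the convention $H^2 = 1-\int\sqrt{pq}$, which equals $\tfrac12\int(\sqrt p-\sqrt q)^2$ for normalized densities; some texts omit the factor $\tfrac12$. The bimodal-versus-collapsed comparison is an illustrative toy case of mode collapse, not data from the paper.

```python
import numpy as np

def hellinger(p, q, dx):
    """Hellinger distance between densities p, q tabulated on a uniform grid of spacing dx.

    Convention assumed here: H^2 = 1 - int sqrt(p q) dx, so H lies in [0, 1].
    """
    bc = np.sum(np.sqrt(p * q)) * dx             # Bhattacharyya coefficient (Riemann sum)
    return np.sqrt(max(0.0, 1.0 - bc))           # clip tiny negative values from quadrature error

def gaussian_pdf(x, mean, std=1.0):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

grid = np.linspace(-12.0, 12.0, 4801)
dx = grid[1] - grid[0]
# Toy example: bimodal target versus a generated density collapsed onto one mode.
target = 0.5 * gaussian_pdf(grid, -2.0) + 0.5 * gaussian_pdf(grid, 2.0)
collapsed = gaussian_pdf(grid, 2.0)
h_collapse = hellinger(target, collapsed, dx)
```

For two unit-variance Gaussians with means $m_1, m_2$, the Bhattacharyya coefficient is $e^{-(m_1-m_2)^2/8}$. This closed form gives a convenient check of the grid-based estimate.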

Theorems & Definitions (2)

  • Corollary 2.3
  • Remark 4.1