On the Feature Learning in Diffusion Models

Andi Han, Wei Huang, Yuan Cao, Difan Zou

TL;DR

This work develops a theoretical framework to analyze how diffusion models learn features during training and contrasts it with supervised classification. In a two-patch data setting with orthogonal signal and noise, diffusion models learn signal and noise in a balanced way: at the stationary point, the ratio of signal alignment to noise alignment in the weights is Θ(n SNR^2). Classification, by contrast, shows a phase transition governed by n SNR^2, where high values favor signal learning and low values favor memorizing noise. Experiments on synthetic data and on real-world Noisy-MNIST/MNIST match the predicted dynamics: diffusion maintains balanced feature learning across regimes, while classification tends to take a shortcut, focusing on either signal or noise depending on the SNR. The findings offer insight into the robustness and representation-learning power of diffusion models and point to extensions to more complex data, conditional/latent diffusion, and other generative paradigms.

Abstract

The predominant success of diffusion models in generative modeling has spurred significant interest in understanding their theoretical foundations. In this work, we propose a feature learning framework aimed at analyzing and comparing the training dynamics of diffusion models with those of traditional classification models. Our theoretical analysis demonstrates that diffusion models, due to the denoising objective, are encouraged to learn more balanced and comprehensive representations of the data. In contrast, neural networks with a similar architecture trained for classification tend to prioritize learning specific patterns in the data, often focusing on easy-to-learn components. To support these theoretical insights, we conduct several experiments on both synthetic and real-world datasets, which empirically validate our findings and highlight the distinct feature learning dynamics in diffusion models compared to classification.

Paper Structure

This paper contains 41 sections, 42 theorems, 229 equations, 15 figures.

Key Result

Theorem 1.1

Let ${\mathrm{SNR}} \coloneqq \| {\boldsymbol{\mu}}\|/(\sigma_\xi \sqrt{d})$ be the signal-to-noise ratio. We can show

Figures (15)

  • Figure 1: Illustration of the ratio of signal learning to noise learning when varying $n \cdot {\mathrm{SNR}}^2$, where ${\mathrm{SNR}} \coloneqq \| {\boldsymbol{\mu}}\|/(\sigma_\xi \sqrt{d})$. The diffusion model tends to learn signal and noise in a more balanced way, while classification exhibits a sharp phase transition and tends to focus on learning either signal or noise.
  • Figure 2: Experiments on the synthetic dataset with both low SNR ($n \cdot {\mathrm{SNR}}^2 = 0.75$) and high SNR ($n \cdot {\mathrm{SNR}}^2 = 6.75$). In the low-SNR setting, noise learning quickly dominates signal learning for the classification task, and in the high-SNR setting, signal learning quickly dominates noise learning. Meanwhile, the diffusion model converges to a stationary point whose signal-to-noise learning ratio is of the order of $n \cdot {\mathrm{SNR}}^2$. More experimental results on additional SNR values are in Appendix \ref{app:additional_snr}.
  • Figure 3: Experiments on Noisy-MNIST with $\widetilde{{\mathrm{SNR}}} = 0.1$. (First row): Test Noisy-MNIST images; (Second row): Illustration of input gradient, i.e., $\nabla_{\mathbf x} F_{+1}({\mathbf W}, {\mathbf x})$ when $y = 1$ and $\nabla_{\mathbf x} F_{-1}({\mathbf W}, {\mathbf x})$ when $y = 0$. (Third row): Denoised image from the diffusion model. In this low-SNR case, classification tends to predominantly learn noise, while diffusion learns both signal and noise.
  • Figure 4: Experiments on Noisy-MNIST with $\widetilde{{\mathrm{SNR}}} = 0.1$. (a) Train loss for classification. (b) Train loss for diffusion model. (c) Feature learning dynamics.
  • Figure 5: Experiments on the synthetic dataset with both low SNR ($n \cdot {\mathrm{SNR}}^2 = 0.75$) and high SNR ($n \cdot {\mathrm{SNR}}^2 = 6.75$).
  • ...and 10 more figures
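
To make the quantity $n \cdot {\mathrm{SNR}}^2$ in the captions above concrete, here is a minimal NumPy sketch that evaluates ${\mathrm{SNR}} = \| {\boldsymbol{\mu}}\|/(\sigma_\xi \sqrt{d})$ and draws toy two-patch samples with noise orthogonal to the signal. Everything beyond the SNR formula is an assumption made for illustration: the sampler `sample_two_patch`, the Gaussian noise projected away from the signal direction, the label convention y ∈ {-1, +1}, and the values of `n`, `d`, `sigma_xi` and `mu` are hypothetical and are not taken from the paper's Definition 2.1 or its experimental configuration.

```python
import numpy as np


def snr(mu: np.ndarray, sigma_xi: float) -> float:
    """SNR = ||mu|| / (sigma_xi * sqrt(d)), where d is the patch dimension."""
    d = mu.shape[0]
    return float(np.linalg.norm(mu) / (sigma_xi * np.sqrt(d)))


def sample_two_patch(n, mu, sigma_xi, rng):
    """Hypothetical two-patch sampler (an assumption, not the paper's definition).

    Returns labels y in {-1, +1}, a signal patch y * mu, and a Gaussian
    noise patch with its component along mu projected out.
    """
    d = mu.shape[0]
    y = rng.choice([-1.0, 1.0], size=n)
    signal = y[:, None] * mu[None, :]
    xi = sigma_xi * rng.standard_normal((n, d))
    unit = mu / np.linalg.norm(mu)
    xi = xi - np.outer(xi @ unit, unit)  # make noise orthogonal to mu
    return y, signal, xi


# Illustrative sizes chosen only so that n * SNR^2 lands on the two
# values quoted in the Figure 2 caption; they are not the paper's settings.
rng = np.random.default_rng(0)
n, d, sigma_xi = 30, 1000, 1.0

mu_low = np.zeros(d)
mu_low[0] = 5.0    # ||mu|| = 5
mu_high = np.zeros(d)
mu_high[0] = 15.0  # ||mu|| = 15

n_snr2_low = n * snr(mu_low, sigma_xi) ** 2
n_snr2_high = n * snr(mu_high, sigma_xi) ** 2
print(f"low regime:  n * SNR^2 = {n_snr2_low:.2f}")
print(f"high regime: n * SNR^2 = {n_snr2_high:.2f}")

y, signal, xi = sample_two_patch(n, mu_low, sigma_xi, rng)
```

With these illustrative values the two regimes evaluate to roughly 0.75 and 6.75, and `xi @ mu_low` is zero for every sample. Since $n \cdot {\mathrm{SNR}}^2 = n \| {\boldsymbol{\mu}}\|^2 / (\sigma_\xi^2 d)$, the regime can be moved equally by changing the sample count, the signal norm, the noise scale, or the dimension.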

Theorems & Definitions (80)

  • Theorem 1.1: Informal
  • Definition 2.1: Data distribution
  • Theorem 3.1: Diffusion model
  • Theorem 3.2: Classification
  • Lemma 4.1
  • Lemma 4.2
  • Theorem 4.1: Informal
  • Lemma 4.3
  • Remark 4.1
  • Lemma B.1
  • ...and 70 more