Understanding Representation Dynamics of Diffusion Models via Low-Dimensional Modeling

Xiao Li, Zekai Zhang, Xiang Li, Siyi Chen, Zhihui Zhu, Peng Wang, Qing Qu

TL;DR

Diffusion models exhibit unimodal representation dynamics where feature quality peaks at an intermediate noise level; the paper explains this via a low-rank mixture data model (MoLRG) and a tractable denoiser parameterization. It develops a theoretical framework with a $\mathrm{SNR}$-based representation metric and proves that the optimal denoiser aligns with ground-truth subspaces, producing unimodal $\mathrm{SNR}$ curves across diffusion time. Empirically, unimodal dynamics predict model generalization in classification and track transitions to memorization with dataset size, model capacity, and training duration. These results bridge distribution learning and representation learning in diffusion models, offering a principled basis for early stopping and representation-based evaluations.

Abstract

Diffusion models, though originally designed for generative tasks, have demonstrated impressive self-supervised representation learning capabilities. A particularly intriguing phenomenon in these models is the emergence of unimodal representation dynamics, where the quality of learned features peaks at an intermediate noise level. In this work, we conduct a comprehensive theoretical and empirical investigation of this phenomenon. Leveraging the inherent low-dimensionality structure of image data, we theoretically demonstrate that the unimodal dynamic emerges when the diffusion model successfully captures the underlying data distribution. The unimodality arises from an interplay between denoising strength and class confidence across noise scales. Empirically, we further show that, in classification tasks, the presence of unimodal dynamics reliably reflects the generalization of the diffusion model: it emerges when the model generates novel images and gradually transitions to a monotonically decreasing curve as the model begins to memorize the training data.

Paper Structure

This paper contains 48 sections, 6 theorems, 45 equations, 15 figures, 9 tables.

Key Result

Proposition 1

Suppose the data $\bm x_0$ is drawn from a noisy MoLRG data distribution with $K$ classes and noise level $\delta$, as introduced in assum:subspace. Then the optimal $\{\bm U\}$ minimizing the loss eq:ddpm_loss is the ground-truth basis defined in eq:MoG noise, and the optimal DAE $\hat{\bm x}_{\bm \theta}$ [...] where $w_l^\star(\bm{x}_t, t)$ are the coefficients in eq:net_attn when $\{\bm U\} = \{\bm U_l^\star\}$ [...]

Figures (15)

  • Figure 1: Unimodal representation dynamics in diffusion-based representation learning tasks. This unimodal representation pattern has been previously observed in diffusion-based representation learning tasks; see baranchuk2021label, xiang2023denoising, tang2023emergent. To verify this, we train diffusion models on various datasets and evaluate downstream performance using noisy images $\bm{x}_t$ at different timesteps $t$. In both classification and segmentation tasks, the performance consistently follows a unimodal trend, peaking at intermediate noise levels. In (b), "mIoU" denotes mean Intersection over Union, a standard metric used in segmentation tasks.
  • Figure 2: An illustration of MoLRG with different noise levels. We visualize samples drawn from noisy MoLRG with noise levels $\delta = 0.1,\;0.3$ and $K=3$.
  • Figure 3: Feature probing accuracy and associated $\mathrm{SNR}$ dynamics in MoLRG data. In panel (a) we plot the probing accuracy and $\mathrm{SNR}$ with the feature obtained from a learned DAE $\bm{\hat{x}_\theta}$, both of which exhibit a consistent unimodal pattern. The DAE is trained on a 3-class MoLRG dataset with data dimension $n=50$, subspace dimension $d=5$, and noise scale $\delta=0.2$. Additionally, in panel (b) we include the optimal $\mathrm{SNR}$ calculated from the optimal DAE $\bm{\hat{x}_\theta}^\star$ and the derived approximation in \ref{lem:main} as a reference.
  • Figure 4: Illustration of the interplay between the denoising rate and the class confidence rate. The settings follow \ref{fig:csnr_molrg_match}.
  • Figure 5: Dynamics of feature accuracy and associated $\mathrm{SNR}$ on CIFAR10 and TinyImageNet. Feature accuracy is plotted alongside $\mathrm{SNR}(\hat{\bm x}_{\bm \theta},t)$. Feature accuracy is evaluated on the test set, while the empirical $\mathrm{SNR}$ is computed from the training set. Both exhibit a matching unimodal pattern. We use released EDM models karras2022elucidating trained on the CIFAR10 krizhevsky2009learning and ImageNet deng2009imagenet datasets, evaluating them on CIFAR10 and TinyImageNet cs231n, respectively. To compute $\mathrm{SNR}$, we apply PCA on the CIFAR10/TinyImageNet features to extract the bases $\bm{U}_l$. Further details can be found in \ref{app:exp_detail}.
  • ...and 10 more figures
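
The captions for Figures 2, 3, and 5 mention sampling from a noisy MoLRG, extracting class bases with PCA, and computing an $\mathrm{SNR}$. The sketch below is one way to play with those ingredients, using the settings from the Figure 3 caption ($K=3$, $n=50$, $d=5$, $\delta=0.2$). Everything else is an assumption made for illustration and is not taken from this page: the sampling form $\bm x_0 = \bm U_k \bm a + \delta \tilde{\bm U}_k \bm e$ with mutually orthogonal class bases, the ratio "energy in the own-class subspace over energy in the other class subspaces" as a stand-in for the paper's Definition 1, and the helper names `sample_noisy_molrg`, `estimate_bases_pca`, and `subspace_snr`. It uses the raw noisy input as the feature, so it is not expected to reproduce the unimodal curve, which concerns denoiser features.

```python
import numpy as np


def sample_noisy_molrg(num_samples, n=50, d=5, K=3, delta=0.2, seed=0):
    """Draw samples under an ASSUMED noisy-MoLRG form:
    x0 = U_k a + delta * U_others e, with a and e standard Gaussian."""
    rng = np.random.default_rng(seed)
    # One random orthonormal frame, split into K mutually orthogonal d-dim bases.
    Q, _ = np.linalg.qr(rng.standard_normal((n, K * d)))
    bases = [Q[:, k * d:(k + 1) * d] for k in range(K)]
    labels = rng.integers(0, K, size=num_samples)
    X = np.empty((num_samples, n))
    for i, k in enumerate(labels):
        others = np.hstack([bases[l] for l in range(K) if l != k])
        signal = bases[k] @ rng.standard_normal(d)
        off_class = others @ rng.standard_normal(others.shape[1])
        X[i] = signal + delta * off_class
    return X, labels, bases


def estimate_bases_pca(features, labels, d, K):
    """Per-class uncentered PCA: top-d right singular vectors of each class block."""
    estimated = []
    for k in range(K):
        _, _, Vt = np.linalg.svd(features[labels == k], full_matrices=False)
        estimated.append(Vt[:d].T)
    return estimated


def subspace_snr(features, labels, bases):
    """ASSUMED SNR-style ratio: own-class subspace energy / other-class subspace energy."""
    num, den = 0.0, 0.0
    for k, U_k in enumerate(bases):
        F_k = features[labels == k]
        # For orthonormal U, ||U U^T f||^2 equals ||U^T f||^2.
        num += np.sum((F_k @ U_k) ** 2)
        den += sum(np.sum((F_k @ U_l) ** 2) for l, U_l in enumerate(bases) if l != k)
    return num / den


X0, y, U_true = sample_noisy_molrg(2000)
U_hat = estimate_bases_pca(X0, y, d=5, K=3)  # PCA estimate of the class bases
rng = np.random.default_rng(1)
for sigma in (0.0, 1.0, 5.0):
    Xt = X0 + sigma * rng.standard_normal(X0.shape)  # isotropic Gaussian corruption
    print(f"sigma={sigma:.1f}  ratio(true bases)={subspace_snr(Xt, y, U_true):.3f}  "
          f"ratio(PCA bases)={subspace_snr(Xt, y, U_hat):.3f}")
```

With the identity feature the ratio only shrinks as `sigma` grows, since isotropic noise adds energy to every subspace equally. Replacing `Xt` with the output of a trained denoiser is the step that would be needed to probe the unimodal behaviour described in the captions.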

Theorems & Definitions (11)

  • Definition 1
  • Proposition 1
  • Theorem 1
  • Proposition 2
  • proof
  • Theorem 2
  • proof
  • Lemma 1
  • proof
  • Lemma 2
  • ...and 1 more