On the Feature Learning in Diffusion Models
Andi Han, Wei Huang, Yuan Cao, Difan Zou
TL;DR
This work develops a theoretical framework for analyzing how diffusion models learn features during training and contrasts it with supervised classification. In a two-patch data setting with orthogonal signal and noise, diffusion models learn signal and noise in a balanced way: at stationarity, the ratio between signal and noise weight alignment is Θ(n SNR^2). Classification, by contrast, exhibits a phase transition governed by n SNR^2, where high values favor learning the signal and low values favor memorizing the noise. Experiments on synthetic data and on real-world Noisy-MNIST/MNIST confirm the predicted dynamics: diffusion maintains balanced feature learning across regimes, while classification tends to take a shortcut by focusing on either signal or noise, depending on the SNR. The findings offer insight into the robustness and representation-learning power of diffusion models and point to extensions to more complex data, conditional/latent diffusion, and other generative paradigms.
Abstract
The predominant success of diffusion models in generative modeling has spurred significant interest in understanding their theoretical foundations. In this work, we propose a feature learning framework aimed at analyzing and comparing the training dynamics of diffusion models with those of traditional classification models. Our theoretical analysis demonstrates that diffusion models, due to the denoising objective, are encouraged to learn more balanced and comprehensive representations of the data. In contrast, neural networks with a similar architecture trained for classification tend to prioritize learning specific patterns in the data, often focusing on easy-to-learn components. To support these theoretical insights, we conduct several experiments on both synthetic and real-world datasets, which empirically validate our findings and highlight the distinct feature learning dynamics in diffusion models compared to classification.
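To make the two-patch setting from the TL;DR concrete, the following is a minimal sketch of how such data could be generated and how the quantity n SNR^2 could be computed. It is an illustration under stated assumptions, not the authors' code: the function names `make_two_patch_data` and `n_snr_squared` are hypothetical, the signal vectors are taken to be scaled standard basis vectors, the noise is Gaussian with its signal-span coordinates zeroed out to enforce orthogonality, and SNR is assumed to be the signal norm divided by sigma * sqrt(d).

```python
import numpy as np


def make_two_patch_data(n, d, mu_norm, sigma, seed=0):
    """Sample n points, each made of a signal patch and a noise patch.

    Assumptions (not taken from the source): two class signals that are
    scaled standard basis vectors, and Gaussian noise restricted to the
    orthogonal complement of the signal span.
    """
    rng = np.random.default_rng(seed)

    # Two mutually orthogonal signal vectors, one per class.
    mu = np.zeros((2, d))
    mu[0, 0] = mu_norm
    mu[1, 1] = mu_norm

    labels = rng.integers(0, 2, size=n)
    signal = mu[labels]  # shape (n, d): the signal patch of each sample

    # Gaussian noise patch; zeroing the first two coordinates makes it
    # orthogonal to both signal vectors.
    noise = sigma * rng.standard_normal((n, d))
    noise[:, :2] = 0.0

    return signal, noise, labels


def n_snr_squared(n, d, mu_norm, sigma):
    """Compute n * SNR^2, assuming SNR = ||mu|| / (sigma * sqrt(d))."""
    snr = mu_norm / (sigma * np.sqrt(d))
    return n * snr ** 2


if __name__ == "__main__":
    signal, noise, labels = make_two_patch_data(n=50, d=100, mu_norm=5.0, sigma=1.0)
    print("n * SNR^2 =", n_snr_squared(n=50, d=100, mu_norm=5.0, sigma=1.0))
```

Varying `mu_norm`, `sigma`, or `n` moves n SNR^2 up or down, which is the knob the TL;DR describes as separating the signal-learning and noise-memorizing regimes for classification.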
