Adaptive Non-uniform Timestep Sampling for Accelerating Diffusion Model Training
Myunsoo Kim, Donghyeon Ki, Seong-Woong Shim, Byung-Jun Lee
TL;DR
The paper tackles the high computational cost of training diffusion models by identifying non-uniform gradient variance across timesteps as a key bottleneck. It introduces an online, learning-based adaptive timestep sampler $\pi_\phi$ that prioritizes the timesteps whose gradient updates most reduce the variational lower bound $\mathcal{L}_{VLB}$, using a surrogate quantity $\tilde{\Delta}_k^t$ computed from a small subset of timesteps. In experiments on CIFAR-10, CelebA-HQ, and ImageNet with diverse schedules and backbones, the method converges faster and reaches better final fidelity (lower FID) than heuristic acceleration strategies, remains robust to changes in schedule and architecture, and combines effectively with existing heuristics. The approach offers a practical path to faster, more robust diffusion-model training across domains, and could potentially be extended to score-based diffusion models in future work.
Abstract
As highly expressive generative models, diffusion models have demonstrated exceptional success across various domains, including image generation, natural language processing, and combinatorial optimization. However, as data distributions grow more complex, training these models to convergence becomes increasingly computationally intensive. While diffusion models are typically trained using uniform timestep sampling, our research shows that the variance in stochastic gradients varies significantly across timesteps, with high-variance timesteps becoming bottlenecks that hinder faster convergence. To address this issue, we introduce a non-uniform timestep sampling method that prioritizes these more critical timesteps. Our method tracks the impact of gradient updates on the objective for each timestep, adaptively selecting those most likely to minimize the objective effectively. Experimental results demonstrate that this approach not only accelerates the training process, but also leads to improved performance at convergence. Furthermore, our method shows robust performance across various datasets, scheduling strategies, and diffusion architectures, outperforming previously proposed timestep sampling and weighting heuristics that lack this degree of robustness.
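To make the idea of tracking per-timestep impact and sampling non-uniformly concrete, the following is a minimal sketch in plain Python. It is not the paper's learned sampler $\pi_\phi$: it swaps in a simple moving-average score per timestep and a softmax over those scores. The class name `AdaptiveTimestepSampler` and the parameters `temperature`, `momentum`, and `uniform_mix` are hypothetical and chosen only for illustration.

```python
import math
import random


class AdaptiveTimestepSampler:
    """Toy non-uniform timestep sampler (illustrative; not the paper's pi_phi)."""

    def __init__(self, num_timesteps, temperature=1.0, momentum=0.9, uniform_mix=0.1):
        self.num_timesteps = num_timesteps
        self.temperature = temperature    # assumed: softmax sharpness
        self.momentum = momentum          # assumed: moving-average decay
        self.uniform_mix = uniform_mix    # assumed: exploration floor
        # Running estimate of how much an update at each timestep reduced the objective.
        self.scores = [0.0] * num_timesteps

    def probabilities(self):
        # Numerically stable softmax over scores, mixed with a uniform distribution
        # so that every timestep keeps a non-zero sampling probability.
        top = max(self.scores)
        exps = [math.exp((s - top) / self.temperature) for s in self.scores]
        total = sum(exps)
        uniform = 1.0 / self.num_timesteps
        return [
            (1.0 - self.uniform_mix) * e / total + self.uniform_mix * uniform
            for e in exps
        ]

    def sample(self, batch_size, rng=random):
        # Draw a batch of timesteps according to the current probabilities.
        return rng.choices(
            range(self.num_timesteps), weights=self.probabilities(), k=batch_size
        )

    def update(self, timestep, objective_reduction):
        # Exponential moving average of the observed objective reduction.
        old = self.scores[timestep]
        self.scores[timestep] = (
            self.momentum * old + (1.0 - self.momentum) * objective_reduction
        )
```

In a training loop under these assumptions, one would call `sample` to pick the timesteps for a batch, measure how much the objective dropped after the update, and feed that value back through `update`. The uniform mixing term is a design choice that keeps rarely sampled timesteps from being starved of updates.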
