Unlocking Dataset Distillation with Diffusion Models
Brian B. Moser, Federico Raue, Sebastian Palacio, Stanislav Frolov, Andreas Dengel
TL;DR
LD3M makes diffusion-based dataset distillation differentiable end to end: it injects linearly decaying residuals from the initial latent state into every reverse diffusion step, so that the learnable latents $\mathcal{Z}$ and conditioning $\mathbf{c}$ can be optimized by backpropagation. Using a pre-trained latent diffusion model without fine-tuning, LD3M outperforms state-of-the-art GAN-prior methods across ImageNet subsets at both $128\times128$ and $256\times256$ resolutions, by up to $4.8$ percentage points at one image per class (IPC=1) and $4.2$ points at IPC=10. The method is also robust to initialization, and gradient checkpointing gives it practical benefits in memory efficiency and distillation speed. This work paves the way for using diffusion priors to produce compact, high-quality distilled datasets, with relevance to cross-architecture transfer learning and privacy.
Abstract
Dataset distillation seeks to condense datasets into smaller but highly representative synthetic samples. While diffusion models now lead all generative benchmarks, current distillation methods avoid them and rely instead on GANs or autoencoders, or, at best, sampling from a fixed diffusion prior. This trend arises because naive backpropagation through the long denoising chain leads to vanishing gradients, which prevents effective synthetic sample optimization. To address this limitation, we introduce Latent Dataset Distillation with Diffusion Models (LD3M), the first method to learn gradient-based distilled latents and class embeddings end-to-end through a pre-trained latent diffusion model. A linearly decaying skip connection, injected from the initial noisy state into every reverse step, preserves the gradient signal across dozens of timesteps without requiring diffusion weight fine-tuning. Across multiple ImageNet subsets at $128\times128$ and $256\times256$, LD3M improves downstream accuracy by up to 4.8 percentage points (1 IPC) and 4.2 points (10 IPC) over the prior state-of-the-art. The code for LD3M is provided at https://github.com/Brian-Moser/prune_and_distill.
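To make the vanishing-gradient argument concrete, the following toy sketch compares a plain reverse chain with one that mixes the initial state back in at every step. It is an illustration under stated assumptions, not the LD3M implementation: the scalar "denoiser" `toy_denoise_step` is a hypothetical stand-in that simply halves its input, and the weight schedule `w = t / T` is one plausible reading of "linearly decaying", not a formula taken from the text above.

```python
def toy_denoise_step(z, t):
    """Hypothetical stand-in for one reverse diffusion step (assumption).

    A real step would call the pre-trained denoiser; here it just
    contracts the latent so that gradients shrink at every step.
    """
    return 0.5 * z


def reverse_chain(z_T, num_steps, use_residual):
    """Run the reverse process from t = T down to t = 1 on a scalar latent."""
    z = z_T
    for t in range(num_steps, 0, -1):
        z = toy_denoise_step(z, t)
        if use_residual:
            # Assumed schedule: the weight on the initial state decays
            # linearly from 1 at t = T to 1/T at t = 1.
            w = t / num_steps
            z = (1.0 - w) * z + w * z_T
    return z


def sensitivity(num_steps, use_residual, eps=1e-6):
    """Central finite-difference estimate of d z_0 / d z_T at z_T = 1."""
    hi = reverse_chain(1.0 + eps, num_steps, use_residual)
    lo = reverse_chain(1.0 - eps, num_steps, use_residual)
    return (hi - lo) / (2.0 * eps)


plain = sensitivity(40, use_residual=False)
skip = sensitivity(40, use_residual=True)
print(f"without residual: {plain:.3e}")
print(f"with residual:    {skip:.3e}")
```

Without the residual, the sensitivity of the final latent to the initial one is the product of 40 contraction factors and is numerically negligible. With the residual, the last step alone contributes a direct path of weight 1/T, so the sensitivity stays many orders of magnitude larger in this toy setting. In the real method the same idea would apply to the gradients with respect to the learnable latents and conditioning.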
