
Unlocking Dataset Distillation with Diffusion Models

Brian B. Moser, Federico Raue, Sebastian Palacio, Stanislav Frolov, Andreas Dengel

TL;DR

LD3M makes diffusion-based dataset distillation trainable by gradient descent: it injects a linearly decaying residual from the initial latent state into every reverse diffusion step, which enables end-to-end optimization of the learnable latents $\mathcal{Z}$ and conditioning $\mathbf{c}$. Using a pre-trained latent diffusion model without fine-tuning, LD3M outperforms state-of-the-art GAN-prior methods across ImageNet subsets at both $128\times128$ and $256\times256$ resolutions, improving downstream accuracy by up to $4.8$ percentage points at one image per class (IPC=1) and $4.2$ points at IPC=10. The method is also robust to initialization and, with gradient checkpointing, offers practical benefits in memory efficiency and distillation speed. This work paves the way for using diffusion priors to produce compact, high-quality distilled datasets that transfer across architectures and may be useful where privacy is a concern.

Abstract

Dataset distillation seeks to condense datasets into smaller but highly representative synthetic samples. While diffusion models now lead all generative benchmarks, current distillation methods avoid them and rely instead on GANs or autoencoders, or, at best, sampling from a fixed diffusion prior. This trend arises because naive backpropagation through the long denoising chain leads to vanishing gradients, which prevents effective synthetic sample optimization. To address this limitation, we introduce Latent Dataset Distillation with Diffusion Models (LD3M), the first method to learn gradient-based distilled latents and class embeddings end-to-end through a pre-trained latent diffusion model. A linearly decaying skip connection, injected from the initial noisy state into every reverse step, preserves the gradient signal across dozens of timesteps without requiring diffusion weight fine-tuning. Across multiple ImageNet subsets at 128x128 and 256x256, LD3M improves downstream accuracy by up to 4.8 percentage points (1 IPC) and 4.2 points (10 IPC) over the prior state-of-the-art. The code for LD3M is provided at https://github.com/Brian-Moser/prune_and_distill.
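
The linearly decaying skip connection described above can be illustrated with a small numerical sketch. The exact update rule is not given in this summary, so the weighting used here, $\mathbf{z}_{t-1} = (1 - t/T)\,\hat{\mathbf{z}}_{t-1} + (t/T)\,\mathbf{z}_T$, is an assumption, and `toy_denoiser`, `reverse_process`, and `sensitivity` are hypothetical names introduced only for illustration. The toy denoiser is a simple contraction standing in for a pre-trained LDM step, not the authors' model.

```python
import numpy as np


def toy_denoiser(z, t):
    # Hypothetical stand-in for one pre-trained denoising step:
    # a contraction, so gradients shrink by 0.5 at every step.
    return 0.5 * z


def reverse_process(z_T, num_steps, use_residual=True, denoise_step=toy_denoiser):
    """Run a toy reverse chain from z_T down to z_0.

    With use_residual=True, each step blends the denoised state with the
    initial state z_T using a weight t / num_steps that decays linearly
    as t goes from num_steps to 1 (assumed form, see lead-in).
    """
    z = z_T
    for t in range(num_steps, 0, -1):
        z_hat = denoise_step(z, t)
        if use_residual:
            w = t / num_steps
            z = (1.0 - w) * z_hat + w * z_T
        else:
            z = z_hat
    return z


def sensitivity(num_steps, use_residual, eps=1e-3):
    # Finite-difference estimate of d z_0 / d z_T, a proxy for how much
    # gradient signal reaches the learnable initial latent.
    z_T = np.ones(4)
    base = reverse_process(z_T, num_steps, use_residual)
    bumped = reverse_process(z_T + eps, num_steps, use_residual)
    return float(np.mean((bumped - base) / eps))


if __name__ == "__main__":
    for use_residual in (False, True):
        print(use_residual, sensitivity(20, use_residual))
```

In this toy setting the plain chain's sensitivity decays geometrically with the number of steps, while the residual variant keeps it bounded below by roughly $1/T$, because the last step always passes a $1/T$ share of $\mathbf{z}_T$ straight through. This only mirrors the qualitative argument in the abstract; it says nothing about the accuracy numbers reported for the real model.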

Paper Structure (18 sections, 10 equations, 18 figures, 12 tables, 1 algorithm)

This paper contains 18 sections, 10 equations, 18 figures, 12 tables, 1 algorithm.

Figures (18)

  • Figure 1: The LD3M Framework. Learnable latent codes $\mathcal{Z}$ and conditioning codes $\mathbf{c}$ are optimized. $\mathcal{Z}$ is noised to initialize the reverse diffusion at $\mathbf{z}_T$. A pre-trained LDM denoiser iteratively refines the state ($\mathbf{z}_t \to \mathbf{z}_{t-1}$). Key innovation: Residual connections (red arrows) inject $\mathbf{z}_T$ with linearly decaying weight into each step (Eq. \ref{eq:intLatent}), enhancing gradient flow. The final latent $\mathbf{z}_0$ is decoded ($\mathcal{D}$) into images $\mathcal{S}$, which are optimized using a standard distillation algorithm.
  • Figure 2: Visual comparison of LD3M versus GLaD (MTT, ImageNette, 1K iter.). GLaD outputs tend toward smooth, photo-realistic textures but can blur class-defining details, whereas LD3M produces bolder, higher-contrast shapes that highlight key discriminative features (e.g., wing contours, beak outline). This abstraction trade-off suggests LD3M prioritizes core class signals over pixel-perfect fidelity, which empirically enhances downstream model generalization, in contrast to claims made by sampling-based methods like D4M [su2024d].
  • Figure 3: Example $256\times256$ images of a distilled class (ImageNet-B: Lorikeet) produced by the GLaD and LD3M generators under different initializations. The initialization, i.e., the dataset on which the generator was trained, is denoted at the bottom. We used DC as the distillation algorithm.
  • Figure 4: Accuracy vs. Distillation Time Trade-off with Diffusion Steps $T$ (ImageNet A-E avg., MTT, IPC=1). LD3M performance (blue line, mean $\pm$ std) peaks around $T=35$. GLaD baseline (dashed lines) and optimal trade-off (X) shown for reference.
  • Figure 5: Influence of our modified reverse process in a classical image generation setting (unconditional FFHQ). It shows that the residual connections alter the generation process significantly, leading to abstract artifacts and the loss of coherence expected in a facial dataset: (top) with modification and (bottom) without modification.
  • ...and 13 more figures