Toward Theoretical Insights into Diffusion Trajectory Distillation via Operator Merging

Weiguo Gao, Ming Li

TL;DR

The paper provides a principled, theory-backed view of diffusion trajectory distillation by treating each denoising step as a linear operator and framing the compression of a full teacher trajectory as an operator-merging problem under signal shrinkage. It introduces a dynamic-programming algorithm to compute an optimal merging plan that minimizes the Wasserstein-2 distance to a surrogate composite operator, and proves a sharp phase transition in the optimal strategy as data variance $\lambda$ varies, with sequential BOOT optimal for $\lambda \le 1$ and vanilla distillation optimal for $\lambda \gg 1$. The work unifies several canonical distillation methods and explains their empirical effectiveness as extremes of a broader merging landscape, supported by DP experiments on synthetic data and real latent spaces (e.g., CelebA MSSIMVAE). Practically, it offers principled guidance for selecting and designing distillation strategies based on data covariance, optimization budgets, and desired speedups in diffusion sampling.

Abstract

Diffusion trajectory distillation methods aim to accelerate sampling in diffusion models, which produce high-quality outputs but suffer from slow sampling speeds. These methods train a student model to approximate the multi-step denoising process of a pretrained teacher model in a single step, enabling one-shot generation. However, theoretical insights into the trade-off between different distillation strategies and generative quality remain limited, complicating their optimization and selection. In this work, we take a first step toward addressing this gap. Specifically, we reinterpret trajectory distillation as an operator merging problem in the linear regime, where each step of the teacher model is represented as a linear operator acting on noisy data. These operators admit a clear geometric interpretation as projections and rescalings corresponding to the noise schedule. During merging, signal shrinkage occurs as a convex combination of operators, arising from both discretization and limited optimization time of the student model. We propose a dynamic programming algorithm to compute the optimal merging strategy that maximally preserves signal fidelity. Additionally, we demonstrate the existence of a sharp phase transition in the optimal strategy, governed by data covariance structures. Our findings enhance the theoretical understanding of diffusion trajectory distillation and offer practical insights for improving distillation strategies.

Paper Structure

This paper contains 49 sections, 13 theorems, 97 equations, 19 figures, 4 tables, 1 algorithm.

Key Result

Proposition 3.1

Assume the real data distribution $p_0$ satisfies the paper's assumption on the real data distribution, the forward process follows the paper's forward-process equation, and the denoising estimator minimizes the expected denoising loss. Then the optimal denoising estimator for $\bm{z}_t$ is linear in $\bm{z}_t$.

Figures (19)

  • Figure 1: Geometric interpretation of the signal-noise vectors $\bm{v}_t^i = (\alpha_t\sqrt{\lambda_i}, \sigma_t)$, which lie on an ellipse. The single-operator formula corresponds to projecting $\bm{v}_{t-1}^i$ onto $\bm{v}_t^i$ and taking the ratio of that projection to $\|\bm{v}_t^i\|_2$. Left: full sequence of vectors ($T=8$). Middle: projection when $\lambda_i < 1$. Right: projection when $\lambda_i > 1$.
  • Figure 2: First row: Error gap between four canonical strategies and the DP-optimal solution as a function of $\lambda$, with $T=32$ and $s=6.4$. As predicted by Theorem 4.2 and Theorem 4.3, sequential BOOT achieves optimality when $\lambda \leq 1$, while vanilla trajectory distillation becomes optimal for sufficiently large $\lambda>2$. Second row: Visualization of the DP-optimal merge plans at $\lambda=1.08$ (left) and $\lambda=2$ (right). Each arc represents a merge operation, with lighter colors indicating earlier merges and darker colors indicating later merges. For additional results under varying $T$, $s$, and covariance matrices $\bm{\Lambda}$, see the appendix on additional experimental results on dynamic programming.
  • Figure 3: Conceptual illustration of vanilla distillation. The entire teacher trajectory is merged into a single step.
  • Figure 4: Conceptual illustration of progressive distillation. The student model is trained with a curriculum that progressively merges pairs of teacher steps into a single student update, thereby halving the number of sampling steps required at each stage.
  • Figure 5: Conceptual illustration of sequential consistency distillation. The student model is trained to map each $\bm{z}_t$ directly to the clean sample $\bm{z}_0$, progressively learning to compress longer teacher trajectories into a single step.
  • ...and 14 more figures
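
The projection described in the Figure 1 caption can be sketched numerically. The snippet below is a minimal illustration, not the paper's code: it reads the caption as saying that the per-step scalar coefficient along one eigen-direction is $\langle \bm{v}_{t-1}^i, \bm{v}_t^i \rangle / \|\bm{v}_t^i\|_2^2$, and it assumes a hypothetical variance-preserving schedule $\alpha_t = \cos(\pi t / 2T)$, $\sigma_t = \sin(\pi t / 2T)$. The function names and the schedule are assumptions introduced for illustration only. Multiplying the per-step coefficients gives a scalar stand-in for a composite operator over several steps.

```python
import math


def schedule(t, T):
    """Hypothetical variance-preserving schedule (assumption): alpha^2 + sigma^2 = 1."""
    theta = 0.5 * math.pi * t / T
    return math.cos(theta), math.sin(theta)


def signal_noise_vector(t, T, lam):
    """v_t = (alpha_t * sqrt(lambda), sigma_t) for one eigenvalue lambda."""
    alpha, sigma = schedule(t, T)
    return (alpha * math.sqrt(lam), sigma)


def step_coefficient(t, T, lam):
    """Assumed reading of the caption: project v_{t-1} onto v_t,
    then divide the scalar projection by ||v_t||_2."""
    v_prev = signal_noise_vector(t - 1, T, lam)
    v_cur = signal_noise_vector(t, T, lam)
    dot = v_prev[0] * v_cur[0] + v_prev[1] * v_cur[1]
    norm_sq = v_cur[0] ** 2 + v_cur[1] ** 2
    return dot / norm_sq


def composite_coefficient(T, lam):
    """Product of per-step coefficients from t = T down to t = 1."""
    prod = 1.0
    for t in range(T, 0, -1):
        prod *= step_coefficient(t, T, lam)
    return prod


if __name__ == "__main__":
    for lam in (0.5, 1.0, 2.0):
        print(lam, composite_coefficient(8, lam))
```

Under this assumed schedule and $\lambda = 1$, every $\bm{v}_t$ has unit norm, so each step coefficient reduces to the cosine of the angle between consecutive vectors, $\cos(\pi / 2T)$. For $\lambda \neq 1$ the vectors lie on an ellipse and the coefficients vary with $t$.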

Theorems & Definitions (21)

  • Definition 2.1: Composite operator of multiple steps
  • Proposition 3.1: Optimal denoising estimator is linear
  • Corollary 3.1: Relationship between consecutive steps is linear
  • Proposition 3.2: Teacher model contracts the covariance
  • Proposition 3.3: Student interpolation under gradient flow
  • Definition 4.1: Surrogate composite operator
  • Theorem 4.1: Optimality of dynamic programming merge
  • Theorem 4.2: Vanilla trajectory distillation is optimal when $\lambda \gg 1$
  • Theorem 4.3: Sequential BOOT is optimal when $\lambda \leq 1$
  • Proposition C.1: Optimal denoising estimator is linear
  • ...and 11 more