Physics Informed Distillation for Diffusion Models

Joshua Tian Jin Tee, Kang Zhang, Hee Suk Yoon, Dhananjaya Nagaraja Gowda, Chanwoo Kim, Chang D. Yoo

TL;DR

The paper tackles the slow sampling of diffusion models by reframing the teacher as a probability-flow ODE and training a student to learn the trajectory $\mathbf{x}_{\theta}(\mathbf{z}, t)$ in a physics-informed manner. By adopting a PINN-inspired residual loss, a stable boundary-preserving parametrization, and numerical differentiation, PID enables fast, single-step generation without synthetic data, and provides theoretical bounds linking discretization to trajectory accuracy. Empirical results on CIFAR-10 and ImageNet 64x64 show PID achieving competitive FID/IS with single-step sampling, albeit with higher training cost than some data-intensive distillation methods; the approach excels in not requiring synthetic data or heavy hyperparameter tuning. Overall, PID offers a practical, data-free distillation pathway for diffusion models with predictable behavior across discretization settings and a transparent training objective grounded in the underlying ODE dynamics.

Abstract

Diffusion models have recently emerged as a potent tool in generative modeling. However, their inherent iterative nature often results in sluggish image generation due to the requirement for multiple model evaluations. Recent progress has unveiled the intrinsic link between diffusion models and Probability Flow Ordinary Differential Equations (ODEs), thus enabling us to conceptualize diffusion models as ODE systems. Simultaneously, Physics Informed Neural Networks (PINNs) have substantiated their effectiveness in solving intricate differential equations through implicit modeling of their solutions. Building upon these foundational insights, we introduce Physics Informed Distillation (PID), which employs a student model to represent the solution of the ODE system corresponding to the teacher diffusion model, akin to the principles employed in PINNs. Through experiments on CIFAR-10 and ImageNet 64x64, we observe that PID achieves performance comparable to recent distillation methods. Notably, it demonstrates predictable trends concerning method-specific hyperparameters and eliminates the need for synthetic dataset generation during the distillation process. Both properties contribute to its ease of use as a distillation approach for diffusion models. Our code and pre-trained checkpoint are publicly available at: https://github.com/pantheon5100/pid_diffusion.git.

Paper Structure

This paper contains 31 sections, 1 theorem, 21 equations, 14 figures, 7 tables, 1 algorithm.

Key Result

Lemma 1

Assuming $D_{\phi}(\mathbf{x},t)$ is Lipschitz continuous with respect to $\mathbf{x}$, if $\mathcal{L}_\text{PID} = 0$, then $\|\mathbf{x}_{\theta}(\mathbf{z},t)-\mathbf{x}(\mathbf{z},t)\|_2\leq \mathcal{O}(\Delta t)$, where $\Delta t= \max_{i\in[0,N-1]}|t_{i+1}-t_i|$.
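
To build intuition for the $\mathcal{O}(\Delta t)$ rate, the sketch below checks numerically how a first-order (Euler) discretization of a probability-flow ODE behaves as the time grid is refined. It does not train a student and is not the paper's loss or proof. Everything in it is an assumption made for illustration: the ODE is taken in the EDM-style form $d\mathbf{x}/dt = (\mathbf{x} - D(\mathbf{x},t))/t$, the teacher is replaced by the closed-form denoiser for 1-D Gaussian data, and the names `denoiser`, `exact_trajectory`, `euler_trajectory`, and the time range are hypothetical.

```python
import math

SIGMA_DATA = 1.0            # std of the toy 1-D Gaussian data (assumption)
T_START, T_END = 2.0, 0.1   # toy noise range, integrated from high to low noise


def denoiser(x, t):
    """Closed-form optimal denoiser for x0 ~ N(0, SIGMA_DATA^2), x_t = x0 + t * eps.

    Stands in for the teacher D_phi; it is linear in x, hence Lipschitz in x.
    """
    return SIGMA_DATA**2 / (SIGMA_DATA**2 + t**2) * x


def exact_trajectory(z, t):
    """Exact solution of dx/dt = (x - D(x, t)) / t with x(T_START) = z."""
    return z * math.sqrt((SIGMA_DATA**2 + t**2) / (SIGMA_DATA**2 + T_START**2))


def euler_trajectory(z, n_steps):
    """Euler integration of the same ODE on a uniform grid with n_steps steps."""
    step = (T_END - T_START) / n_steps  # negative: time runs backwards
    x = z
    for i in range(n_steps):
        t = T_START + i * step
        x = x + step * (x - denoiser(x, t)) / t
    return x


z = 1.5  # arbitrary starting point at t = T_START
errors = {
    n: abs(euler_trajectory(z, n) - exact_trajectory(z, T_END))
    for n in (50, 100, 200)
}
for n, err in errors.items():
    print(f"N={n:4d}  dt={(T_START - T_END) / n:.4f}  |error|={err:.3e}")
```

In this toy setting, halving $\Delta t$ roughly halves the endpoint error, which is the first-order behavior the lemma describes for the student trajectory when the loss vanishes.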

Figures (14)

  • Figure 1: An overview of the proposed method, which involves training a model $\mathbf{x}_{\theta}(\mathbf{z}, \cdot )$ to approximate the true trajectory $\mathbf{x}(\mathbf{z}, \cdot )$.
  • Figure 2: Conditional image generation comparison on ImageNet 64$\times$64 for the same seed with the same class label "Siberian husky". Left panel: random samples generated by teacher model EDM. Right panel: generated by student PID model.
  • Figure 3: Comparison between automatic differentiation and numerical differentiation on CIFAR-10 dataset.
  • Figure 4: Comparison between student model weights randomly initialized and initialized with teacher model weights on the CIFAR-10 dataset.
  • Figure 5: The impact of stop gradient in the PID training on CIFAR-10 dataset.
  • ...and 9 more figures

Theorems & Definitions (2)

  • Lemma 1
  • proof