Designing Parameter and Compute Efficient Diffusion Transformers using Distillation

Vignesh Sundaresha

TL;DR

This work addresses the challenge of deploying diffusion transformers (DiTs) on edge devices by applying knowledge distillation to create parameter- and compute-efficient DiTs (DiT-Nano). It develops design principles for sizing DiTs (depth, width, and heads) and introduces two distillation schemes, Teaching Assistant (TA) and Multi-In-One (MI1), with a practical emphasis on one-step diffusion and offline teacher signals. Empirical results on CIFAR-10 demonstrate that LPIPS-based GET distillation yields strong performance, with a favorable trade-off between model size, image quality (FID), and latency on edge hardware, outperforming a SOTA diffusion-distillation baseline on several metrics. The findings offer actionable guidelines for edge-ready diffusion models and point to future work on analytic justifications and broader design-space exploration for real-world applications.
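The "offline teacher signals" setting can be illustrated with a minimal sketch. The slow multi-step teacher is run once to cache (noise, sample) pairs, and the one-step student is then scored against those cached samples. This is an illustration under stated assumptions, not the authors' implementation. Plain mean squared error stands in for LPIPS, which requires a pretrained network. Vectors are Python lists. The helper names `build_offline_pairs` and `distillation_loss` are hypothetical. Only the objective is computed; the gradient-based training loop is omitted.

```python
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Model = Callable[[Vector], Vector]


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error, used here only as a stand-in for LPIPS."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def build_offline_pairs(teacher: Model, num_pairs: int, dim: int,
                        seed: int = 0) -> List[Tuple[Vector, Vector]]:
    """Run the (slow, multi-step) teacher once and cache (noise, sample) pairs."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(num_pairs):
        noise = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        pairs.append((noise, teacher(noise)))
    return pairs


def distillation_loss(student: Model, pairs: List[Tuple[Vector, Vector]],
                      distance: Callable[[Sequence[float], Sequence[float]], float] = mse) -> float:
    """Average distance between one-step student outputs and cached teacher samples."""
    return sum(distance(student(noise), target) for noise, target in pairs) / len(pairs)
```

For example, with a toy teacher `lambda z: [0.5 * v for v in z]`, a student identical to the teacher scores exactly zero. Any other mapping scores above zero. Because the teacher is only queried when the pairs are built, the student's training cost is independent of how many sampling steps the teacher needs.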

Abstract

Diffusion Transformers (DiTs) with billions of model parameters form the backbone of popular image and video generation models like DALL·E, Stable Diffusion and SORA. Although such models are needed in many low-latency applications like Augmented/Virtual Reality, they cannot be deployed on resource-constrained Edge devices (like Apple Vision Pro or Meta Ray-Ban glasses) because of their high computational cost. To overcome this, we turn to knowledge distillation and perform a thorough design-space exploration to find the best DiT for a given parameter budget. In particular, we provide principles for how to choose design knobs such as depth, width, attention heads and distillation setup for a DiT. In the process, a three-way trade-off emerges between model performance, size and speed that is crucial for Edge implementation of diffusion. We also propose two distillation approaches, the Teaching Assistant (TA) method and the Multi-In-One (MI1) method, to perform feature distillation in the DiT context. Unlike existing solutions, we demonstrate and benchmark the efficacy of our approaches on practical Edge devices such as NVIDIA Jetson Orin Nano.

Paper Structure

This paper contains 18 sections, 7 equations, 4 figures, 6 tables.

Figures (4)

  • Figure 1: Design space exploration of diffusion distillation.
  • Figure 2: (Left) Baseline approach, which performs regular knowledge distillation using an offline teacher (geng2024one). (Center) Teaching Assistant (TA), which performs layer-wise distillation using an online teaching assistant and an offline teacher. (Right) Multi-In-One (MI1), which maps multiple diffusion steps into a single step by assigning different diffusion timesteps to different layers of a DiT. The training targets are obtained using forward diffusion of the probability flow ODE. See the algorithms appendix for implementation details of TA and MI1.
  • Figure 3: (a) Impact of depth on FID and number of parameters; the number of parameters grows linearly with depth. (b) Impact of width on FID and number of parameters; the number of parameters grows quadratically with width. Both (a) and (b) show diminishing returns when each dimension is increased independently: FID does not improve in proportion to the added parameters, especially for width. (c) Impact of the number of attention heads on FID. In each plot, whichever of depth ($d$), width ($w$) and number of attention heads ($h$) is held fixed is set to $d=6$, $w=128$ and $h=4$, respectively.
  • Figure 4: (a) Impact of the number of parameters on FID for deep DiTs, wide DiTs and the proposed DiT sizing. (b) Impact of the number of parameters on latency (on NVIDIA Jetson Nano). Wide DiTs have noticeably worse FID, while deep DiTs have significantly higher latency due to the serial processing of layers. The 0.42M-parameter ablation table shows that FID worsens when going any deeper or wider. Our proposed sizing provides near-optimal FID at a much lower latency. (c) The three-way trade-off between number of parameters (memory), FID (image quality) and latency (image frame rate) among optimal DiTs.
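The Figure 2 caption describes MI1 as assigning different diffusion timesteps to different DiT layers. The sketch below illustrates one way that assignment and the resulting per-layer loss could look. The details here are not taken from the paper and are hypothetical:

- the even spacing of supervised layers;
- the helper names `assign_timesteps_to_layers` and `mi1_loss`;
- the assumption that each supervised layer's hidden state has already been projected into the same space as the teacher's trajectory state.

```python
from typing import Dict, Sequence


def assign_timesteps_to_layers(depth: int, timesteps: Sequence[int]) -> Dict[int, int]:
    """Map teacher timesteps (ordered noisy -> clean) to DiT layer indices.

    Hypothetical even spacing: the k-th of K timesteps supervises layer
    (k + 1) * depth // K - 1, so the final (clean) timestep lands on the last layer.
    """
    num_steps = len(timesteps)
    if not 0 < num_steps <= depth:
        raise ValueError("need between 1 and `depth` timesteps")
    return {(k + 1) * depth // num_steps - 1: t for k, t in enumerate(timesteps)}


def mi1_loss(layer_outputs: Sequence[Sequence[float]],
             trajectory: Dict[int, Sequence[float]],
             mapping: Dict[int, int]) -> float:
    """Mean per-layer MSE between (projected) layer outputs and trajectory states.

    `layer_outputs[i]` is assumed to already live in the same space as the
    teacher state `trajectory[t]`; the projection head is not shown.
    """
    total = 0.0
    for layer, t in mapping.items():
        out, target = layer_outputs[layer], trajectory[t]
        total += sum((a - b) ** 2 for a, b in zip(out, target)) / len(out)
    return total / len(mapping)
```

With the default depth of 6 and four teacher timesteps `[750, 500, 250, 0]`, this spacing supervises layers 0, 2, 3 and 5. Layers 1 and 4 are left free. A single forward pass of the student therefore walks through the whole teacher trajectory.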
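The linear-in-depth and quadratic-in-width trends in Figure 3 follow from the shapes of the weight matrices in a transformer block. The sketch below counts them under an explicit assumption: the original DiT block layout, with a QKV projection, an attention output projection, a two-layer MLP with ratio 4, adaLN-Zero modulation and non-affine LayerNorms. The paper's exact architecture may differ. Patch/timestep embeddings and the final layer are omitted, and the function names are hypothetical.

```python
def dit_block_params(width: int, mlp_ratio: int = 4) -> int:
    """Weights + biases of one DiT block, assuming the original DiT layout."""
    qkv = 3 * width * width + 3 * width                    # Linear(w, 3w)
    attn_out = width * width + width                       # Linear(w, w)
    mlp = (mlp_ratio * width * width + mlp_ratio * width   # Linear(w, r*w)
           + mlp_ratio * width * width + width)            # Linear(r*w, w)
    adaln = 6 * width * width + 6 * width                  # adaLN-Zero: Linear(w, 6w)
    return qkv + attn_out + mlp + adaln                    # LayerNorms carry no weights


def dit_stack_params(depth: int, width: int, heads: int, mlp_ratio: int = 4) -> int:
    """Parameters of `depth` stacked blocks (embeddings and final layer omitted)."""
    if width % heads != 0:
        raise ValueError("width must be divisible by the number of heads")
    # Heads only split width into head_dim = width // heads; no weight shape changes.
    return depth * dit_block_params(width, mlp_ratio)
```

Under these assumptions, one block has $18w^2 + 15w$ parameters. Doubling depth therefore doubles the count, while doubling width roughly quadruples it. The number of heads leaves the parameter count unchanged, which is consistent with Figure 3(c) reporting heads only against FID.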
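Figure 4(c) frames DiT sizing as a three-way trade-off in which all three axes (parameters, FID, latency) are lower-is-better. A minimal sketch of extracting the non-dominated configurations from a sweep follows. The `Candidate` fields and function names are hypothetical, and any values fed in would come from one's own measurements.

```python
from typing import List, NamedTuple, Tuple


class Candidate(NamedTuple):
    name: str
    params_m: float    # parameters in millions (memory), lower is better
    fid: float         # image quality, lower is better
    latency_ms: float  # on-device latency, lower is better


def objectives(c: Candidate) -> Tuple[float, float, float]:
    return (c.params_m, c.fid, c.latency_ms)


def dominates(a: Candidate, b: Candidate) -> bool:
    """True if `a` is no worse than `b` on every axis and strictly better on at least one."""
    pairs = list(zip(objectives(a), objectives(b)))
    return all(x <= y for x, y in pairs) and any(x < y for x, y in pairs)


def pareto_front(candidates: List[Candidate]) -> List[Candidate]:
    """Keep the candidates no other candidate dominates (O(n^2), fine for small sweeps)."""
    return [c for c in candidates if not any(dominates(o, c) for o in candidates)]
```

Only configurations on the returned front are worth comparing on a target device. Every other configuration is beaten on all three axes at once by some alternative.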