TREAD: Token Routing for Efficient Architecture-agnostic Diffusion Training

Felix Krause, Timy Phan, Ming Gui, Stefan Andreas Baumann, Vincent Tao Hu, Björn Ommer

TL;DR

Diffusion models are hindered by high training costs and slow convergence. TREAD introduces a token routing mechanism that, during training only, carries randomly selected tokens from early layers to deeper layers, routing them around the layers in between while preserving all information and requiring no architectural changes. The approach yields substantial speedups (up to 37x) and better generative metrics (e.g., lower FID) on ImageNet-256, scales to higher resolutions and text-to-image tasks, and remains compatible with other efficiency techniques such as representation distillation. These results suggest that architecture-agnostic token routing can dramatically speed up diffusion training without sacrificing performance, broadening access to, and accelerating research in, diffusion-based generative modeling.
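The routing idea summarized above can be illustrated with a small NumPy sketch. It is not the authors' implementation: `make_toy_layer`, `forward_with_route`, and the convention that a route from layer `start` to layer `end` lets the selected tokens skip `layers[start:end]` are hypothetical choices made for illustration. The toy layers stand in for transformer or state-space blocks; they only need to accept a variable number of tokens, which is why routing of this kind needs no architectural changes or extra parameters.

```python
import numpy as np


def make_toy_layer(dim, rng):
    """Hypothetical stand-in for a transformer/SSM block that accepts any token count."""
    w = rng.normal(scale=dim ** -0.5, size=(dim, dim))

    def layer(x):
        # Residual update with a simple token-mixing term (subtracting the sequence mean).
        return x + 0.1 * np.tanh((x - x.mean(axis=0, keepdims=True)) @ w)

    return layer


def forward_with_route(x, layers, start, end, ratio, rng):
    """Run `layers` on tokens x of shape (num_tokens, dim) with one route.

    A random fraction `ratio` of the tokens bypasses layers[start:end] and is
    reinserted, unchanged, at its original positions before layers[end].
    """
    assert 0 <= start < end < len(layers), "route must start and end inside the network"
    num_tokens = x.shape[0]
    num_routed = int(round(ratio * num_tokens))
    perm = rng.permutation(num_tokens)
    routed_idx, kept_idx = perm[:num_routed], np.sort(perm[num_routed:])

    for i, layer in enumerate(layers):
        if i == start:
            saved = x[routed_idx]  # tokens that take the route
            x = x[kept_idx]        # only these are processed by the skipped layers
        if i == end:
            full = np.empty((num_tokens, saved.shape[1]), dtype=x.dtype)
            full[kept_idx] = x
            full[routed_idx] = saved  # reinsert routed tokens: nothing is dropped
            x = full
        x = layer(x)
    return x


rng = np.random.default_rng(0)
layers = [make_toy_layer(16, rng) for _ in range(12)]
tokens = rng.normal(size=(64, 16))
routed_out = forward_with_route(tokens, layers, start=2, end=8, ratio=0.5, rng=rng)
```

Setting `ratio=0.0` routes no tokens and reduces to the ordinary forward pass through every layer, mirroring the fact that routing is applied only during training while inference uses the full network.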

Abstract

Diffusion models have emerged as the mainstream approach for visual generation. However, these models typically suffer from sample inefficiency and high training costs. Consequently, methods for efficient finetuning, inference, and personalization were quickly adopted by the community. Yet training these models in the first place remains very costly. Several recent approaches, including masking, distillation, and architectural modifications, have been proposed to improve training efficiency, but each comes with a tradeoff: they achieve better performance at the expense of increased computational cost, or vice versa. In contrast, this work aims to improve training efficiency and generative performance at the same time through routes that act as a transport mechanism for randomly selected tokens from early layers to deeper layers of the model. Our method is not limited to the common transformer-based model; it can also be applied to state-space models, and it requires no architectural modifications or additional parameters. Finally, we show that TREAD reduces computational cost and simultaneously boosts model performance on the standard ImageNet-256 benchmark for class-conditional synthesis. These two benefits multiply into a convergence speedup of 14x at 400K training iterations compared to DiT, and of 37x compared to DiT's best benchmark performance at 7M training iterations. Furthermore, we achieve competitive FIDs of 2.09 in the guided and 3.93 in the unguided setting, improving upon DiT without architectural changes.
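The statement that the two benefits multiply follows from simple bookkeeping: the compute needed to reach a target FID equals the number of iterations times the cost per iteration, so the overall speedup is the product of the iteration ratio and the per-iteration cost ratio. The sketch below makes this explicit. The function names and the equal-cost-per-layer proxy are hypothetical simplifications (the proxy ignores, for example, attention's quadratic cost in the number of tokens), and the numbers in the usage lines are made up for illustration; they are not TREAD's measured values.

```python
def relative_cost_per_iteration(num_layers, skipped_layers, routed_fraction):
    """Per-iteration compute relative to the baseline, under a simple proxy.

    Assumes every layer costs the same per token, so cost is proportional to the
    number of token-layer evaluations; routed tokens skip `skipped_layers` layers.
    """
    return 1.0 - routed_fraction * skipped_layers / num_layers


def convergence_speedup(baseline_iters, method_iters, baseline_cost, method_cost):
    """Speedup in total compute needed to reach the same target metric (e.g. FID).

    Total compute = iterations x cost per iteration, so the speedup is the
    product of the iteration ratio and the per-iteration cost ratio.
    """
    return (baseline_iters / method_iters) * (baseline_cost / method_cost)


# Hypothetical example: half the tokens skip 6 of 12 layers.
cost = relative_cost_per_iteration(12, 6, 0.5)            # 0.75
# Hypothetical example: 4x fewer iterations at half the cost per iteration.
speedup = convergence_speedup(400_000, 100_000, 1.0, 0.5)  # 8.0
```

Neither factor alone accounts for the total: a method that converges in fewer iterations but costs more per iteration can still lose overall, which is the tradeoff the abstract attributes to earlier approaches.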

Paper Structure

This paper contains 33 sections, 7 equations, 15 figures, and 9 tables.

Figures (15)

  • Figure 1: We introduce TREAD, a training strategy that enables substantially more efficient training of token-based diffusion backbones. Applied to the standard DiT backbone [dit_peebles2022scalable], TREAD achieves $14\times$ and $37\times$ training speedups w.r.t. unguided FID (relative to DiT at 400K and 7M iterations, respectively) while also converging to better generation quality.
  • Figure 2: Selected samples from ImageNet-256 generated with $\text{DiT-XL/2}_{+\text{TREAD}}$ using a guidance weight of $\omega=3.5$.
  • Figure 3: TREAD: Our method for efficient diffusion training. a) The standard training and inference strategy, in which all tokens are processed by all layers of the network. b) TREAD improves training efficiency by routing tokens around certain layers, which reduces computational load while preserving information. Since TREAD is used only during training, the standard setting in a) is used for inference.
  • Figure 4: Consecutive layers have highly similar outputs. The effects of the routing mechanism are evident in the cosine similarities between layers. For $\textbf{r}_{3\rightarrow8}$, $L_2$ exhibits high similarity with the routed layers. We interpret this as an adaptation of $L_2$ to $\textbf{r}_{3\rightarrow8}$.
  • Figure 5: TREAD shows good performance at low compute cost. We demonstrate a strictly better performance-cost trade-off than all other presented methods, including DiT+REPA [yu2024repa]. For methods marked with $^*$, we assume the same iteration speed as DiT. This assumption favors these competitors, since they use additional parameters or entire pre-trained vision encoders to aid their diffusion model, which in practice lowers their iteration speed.
  • ...and 10 more figures
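Figure 4 compares layers through the cosine similarity of their outputs. One plausible way to compute such a layer-by-layer similarity matrix is sketched below in NumPy. The function name `layer_similarity` is hypothetical, and averaging per-token cosine similarities is an assumption; the paper may aggregate differently (for example, over flattened feature maps or over a batch of samples).

```python
import numpy as np


def layer_similarity(outputs):
    """Pairwise cosine similarity between layer outputs.

    outputs: list of arrays of shape (num_tokens, dim), one per layer.
    Returns S of shape (num_layers, num_layers), where S[a, b] is the mean over
    tokens of the cosine similarity between layer a's and layer b's output.
    """
    stacked = np.stack(outputs)  # (num_layers, num_tokens, dim)
    normed = stacked / np.linalg.norm(stacked, axis=-1, keepdims=True)
    # Dot products of unit vectors over the feature axis give per-token cosines
    # for every pair of layers.
    per_token = np.einsum("atd,btd->abt", normed, normed)
    return per_token.mean(axis=-1)


rng = np.random.default_rng(0)
features = rng.normal(size=(8, 4))
similarity = layer_similarity([features, 2 * features, -features])
```

In practice, `outputs` would be the per-layer activations recorded during a forward pass; high off-diagonal values for neighboring layers correspond to the "highly similar outputs" described in the caption.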