Token Caching for Diffusion Transformer Acceleration

Jinming Lou, Wenyang Luo, Yufan Liu, Bing Li, Xinmiao Ding, Weiming Hu, Jiajiong Cao, Yuming Li, Chenguang Ma

TL;DR

Diffusion Transformers incur high computational costs due to quadratic-complexity attention and multi-step denoising. The authors propose TokenCache, a post-training acceleration method that prunes tokens and reuses their cached intermediate values, using a Cache Predictor, adaptive block selection, and a Two-Phase Round-Robin (TPRR) scheduling policy to balance speed and quality. The approach decomposes caching into three dimensions (which tokens to prune, which blocks to target, and at which timesteps to cache), enabling fine-grained acceleration while keeping the pretrained DiT weights frozen. Experiments on ImageNet with DiT and MDT show favorable quality-speed trade-offs and outperform block-level baselines, reaching speedups of up to 1.44$\times$ at visual quality similar to the original models (Figure 1). This work demonstrates the viability of token-level caching for diffusion transformers and suggests a fruitful direction for efficient generative modeling.

Abstract

Diffusion transformers have gained substantial interest in diffusion generative modeling due to their outstanding performance. However, their high computational cost, arising from the quadratic computational complexity of attention mechanisms and multi-step inference, presents a significant bottleneck. To address this challenge, we propose TokenCache, a novel post-training acceleration method that leverages the token-based multi-block architecture of transformers to reduce redundant computations among tokens across inference steps. TokenCache specifically addresses three critical questions in the context of diffusion transformers: (1) which tokens should be pruned to eliminate redundancy, (2) which blocks should be targeted for efficient pruning, and (3) at which time steps caching should be applied to balance speed and quality. In response to these challenges, TokenCache introduces a Cache Predictor that assigns importance scores to tokens, enabling selective pruning without compromising model performance. Furthermore, we propose an adaptive block selection strategy to focus on blocks with minimal impact on the network's output, along with a Two-Phase Round-Robin (TPRR) scheduling policy to optimize caching intervals throughout the denoising process. Experimental results across various models demonstrate that TokenCache achieves an effective trade-off between generation quality and inference speed for diffusion transformers. Our code will be publicly available.
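
The token-level reuse described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the function name `cached_block_step`, the `keep_ratio` parameter, the top-k selection rule, and the assumption that a higher score means "recompute this token" are all hypothetical choices made for illustration. Tokens are plain floats and `block` stands in for a transformer block that processes only the kept tokens.

```python
from typing import Callable, List, Sequence


def cached_block_step(
    tokens: Sequence[float],
    scores: Sequence[float],
    cache: Sequence[float],
    block: Callable[[List[float]], List[float]],
    keep_ratio: float,
) -> List[float]:
    """Recompute the highest-scoring tokens and reuse cached outputs for the rest.

    Assumption (not from the paper): a higher score means the token is more
    important and should be recomputed instead of read from the cache.
    """
    n = len(tokens)
    if not (len(scores) == len(cache) == n):
        raise ValueError("tokens, scores, and cache must have the same length")
    n_keep = max(1, round(keep_ratio * n))

    # Rank token indices by importance, then restore positional order.
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    keep = sorted(ranked[:n_keep])

    # Only the kept tokens pass through the block; pruned tokens are skipped.
    fresh = block([tokens[i] for i in keep])

    # Start from the cached outputs and overwrite the recomputed positions.
    out = list(cache)
    for i, value in zip(keep, fresh):
        out[i] = value
    return out


if __name__ == "__main__":
    result = cached_block_step(
        tokens=[1.0, 2.0, 3.0, 4.0],
        scores=[0.9, 0.1, 0.8, 0.2],
        cache=[10.0, 20.0, 30.0, 40.0],
        block=lambda xs: [2.0 * x for x in xs],
        keep_ratio=0.5,
    )
    print(result)
```

In this toy run, tokens 0 and 2 are recomputed while tokens 1 and 3 keep their cached values, so the block processes half as many tokens as it would without caching.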

Paper Structure (25 sections, 5 equations, 7 figures, 11 tables)

Figures (7)

  • Figure 1: Examples of images generated by TokenCache-accelerated diffusion transformers. Left: DiT. Right: MDT. Our method achieves visual quality similar to that of the original (non-accelerated) model with up to 1.44$\times$ speedup.
  • Figure 2: Demonstration of token redundancy in diffusion transformers. (a) and (b) show the heatmaps of changes that each network block applies to each token. (c) and (d) plot the similarity of the output tokens from the same block across different timesteps. (e) visualizes our two-phase timestep schedule.
  • Figure 3: Framework of TokenCache. TokenCache decomposes the space of caching strategies into three "dimensions": 1) which tokens to prune, where we propose the Cache Predictor to estimate token importance for pruning and for reusing cached values; 2) in which blocks to prune tokens, where we adaptively select the least important blocks given the importance of tokens; 3) at which timesteps to perform pruning and caching, where we present the Two-Phase Round-Robin (TPRR) timestep schedule that interleaves non-pruning I-steps and pruning P-steps in two phases. The pretrained DiT weights are frozen, and only the strategy is adapted to the inference configuration.
  • Figure 4: Illustration of our token pruning strategy via Cache Predictor. (a) Training, where caching and non-caching states of the tokens are superposed. (b) Evaluation, where grid-based pruning and reuse of previous values are performed.
  • Figure 5: FID scores for different grid-level pruning ratios under varying target computational costs.
  • ...and 2 more figures
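
The Figure 3 caption describes a schedule that interleaves non-pruning I-steps with pruning P-steps in two phases, but the captions do not give the exact rule. The sketch below shows one plausible reading under stated assumptions: each phase places an I-step at a fixed interval and fills the remaining steps with P-steps. The function name `two_phase_schedule` and the parameters `phase_split`, `interval_1`, and `interval_2` are hypothetical and do not come from the paper.

```python
from typing import List


def two_phase_schedule(
    num_steps: int,
    phase_split: int,
    interval_1: int,
    interval_2: int,
) -> List[str]:
    """Label each denoising step as 'I' (no pruning) or 'P' (pruning).

    Assumed rule, for illustration only: within each phase, every
    `interval`-th step (counted from the start of the phase) is an I-step
    and all other steps are P-steps.
    """
    if not 0 <= phase_split <= num_steps:
        raise ValueError("phase_split must lie between 0 and num_steps")
    if interval_1 < 1 or interval_2 < 1:
        raise ValueError("intervals must be at least 1")

    schedule = []
    for step in range(num_steps):
        if step < phase_split:
            offset, interval = step, interval_1
        else:
            offset, interval = step - phase_split, interval_2
        schedule.append("I" if offset % interval == 0 else "P")
    return schedule


if __name__ == "__main__":
    # Phase 1 covers steps 0-3 with one I-step every 4 steps;
    # phase 2 covers steps 4-7 with one I-step every 2 steps.
    print("".join(two_phase_schedule(8, 4, 4, 2)))
```

A larger interval means more P-steps between refreshes of the cache, which trades quality for speed; using different intervals in the two phases lets the schedule prune more aggressively in one part of the denoising process than in the other.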