Shiva-DiT: Residual-Based Differentiable Top-$k$ Selection for Efficient Diffusion Transformers

Jiaji Zhang, Hailiang Zhao, Guoxuan Zhu, Ruichao Sun, Jiaju Wu, Xinkui Zhao, Hanlin Tang, Weiyi Lu, Kan Liu, Tao Lan, Lin Qu, Shuiguang Deng

TL;DR

Shiva-DiT addresses the prohibitive compute of Diffusion Transformers by enforcing a fixed token budget with a differentiable top-$k$ mechanism that preserves end-to-end learning. The approach combines Residual-Based Differentiable Sorting, a Context-Aware Token Importance Router, and an Adaptive Ratio Policy to dynamically allocate computation while maintaining hardware-friendly static shapes. Key contributions include a gradient estimator that propagates through the hard token selection to learn $k$, a lightweight router that captures spatiotemporal token importance, and a stabilized budget constraint with stratified sampling, achieving a new Pareto frontier (e.g., ~$1.54\times$ wall-clock speedup on SD3.5) with preserved fidelity. The results demonstrate practical speedups for large diffusion backbones on high-resolution inputs, enabling more feasible real-time or resource-constrained deployment of diffusion-based generative models.

Abstract

Diffusion Transformers (DiTs) incur prohibitive computational costs due to the quadratic scaling of self-attention. Existing pruning methods fail to simultaneously satisfy differentiability, efficiency, and the strict static budgets needed to avoid hardware overhead. To address this, we propose Shiva-DiT, which reconciles these conflicting requirements via Residual-Based Differentiable Top-$k$ Selection. By leveraging a residual-aware straight-through estimator, our method enforces deterministic token counts for static compilation while preserving end-to-end learnability through residual gradient estimation. Furthermore, we introduce a Context-Aware Router and Adaptive Ratio Policy to autonomously learn an adaptive pruning schedule. Experiments on mainstream models, including SD3.5, demonstrate that Shiva-DiT establishes a new Pareto frontier, achieving a 1.54$\times$ wall-clock speedup with superior fidelity compared to existing baselines, effectively eliminating ragged tensor overheads.
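
To make the idea of a fixed token budget with a learnable selection more concrete, the following is a minimal sketch of hard top-$k$ selection paired with a plain straight-through gradient. It is not the paper's residual-based estimator, whose details are not given in this summary. The function names, the sigmoid surrogate, and the scalar "tokens" are all assumptions made for illustration.

```python
import math


def topk_mask(scores, k):
    """Return a 0/1 mask that keeps exactly k entries (the k highest scores)."""
    # Sort indices by descending score; ties go to the lower index.
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    keep = set(order[:k])
    return [1.0 if i in keep else 0.0 for i in range(len(scores))]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forward(tokens, scores, k):
    """Forward pass: hard selection, so the output always has exactly k tokens."""
    mask = topk_mask(scores, k)
    kept = [t for t, m in zip(tokens, mask) if m == 1.0]
    return mask, kept


def straight_through_grad(scores, grad_wrt_mask):
    """Backward pass of a plain straight-through estimator (illustrative only).

    The hard mask has zero gradient almost everywhere, so this sketch pretends
    the mask was sigmoid(score) and uses that derivative instead.
    """
    return [g * sigmoid(s) * (1.0 - sigmoid(s))
            for s, g in zip(scores, grad_wrt_mask)]


# Hypothetical example: four scalar "tokens" with router scores, budget k = 2.
tokens = [10.0, 20.0, 30.0, 40.0]
scores = [0.0, 2.0, -1.0, 1.0]
mask, kept = forward(tokens, scores, k=2)
grads = straight_through_grad(scores, grad_wrt_mask=[1.0, 1.0, 1.0, 1.0])
```

The output length depends only on k, never on the scores, which is what allows static tensor shapes. Every score still receives a nonzero surrogate gradient, including those of dropped tokens.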

Paper Structure

This paper contains 71 sections, 31 equations, 17 figures, 5 tables, 1 algorithm.

Figures (17)

  • Figure 1: Overview of Shiva-DiT. (a) Shiva simultaneously resolves the sparse learning trilemma: Differentiability, Efficiency, and Strict Budget. (b) We inject a lightweight Importance Router and Adaptive Ratio Policy into the frozen backbone. Utilizing differentiable sorting, Shiva ensures learnability while maintaining static tensor shapes for hardware efficiency.
  • Figure 2: Qualitative Comparison. Prompt: God Zeus wearing a golden oak leaves crown on his head, grey beard. By directing token reduction to the background, Shiva-DiT preserves high-fidelity subject details (e.g., beard, weaving) comparable to Vanilla, whereas competitors degrade the primary figure.
  • Figure 3: Efficiency-Quality Trade-off. Shiva-DiT pushes the Pareto frontier towards the bottom-left, achieving a superior balance between inference latency and FID compared to baselines.
  • Figure 4: Adaptive Policy Analysis. (a) The adaptive policy achieves consistently lower validation MSE, indicating superior convergence. (b) The heatmap illustrates the pruning ratios across network layers ($y$-axis) and diffusion timesteps ($x$-axis), revealing a clear "structure-first, detail-later" preference.
  • Figure 5: Visualization of Learned Masks (Group-12). The pairwise router effectively identifies semantic regions (e.g., foreground objects) and texture-rich areas, demonstrating that locally shared policies can accurately capture token importance.
  • ...and 12 more figures
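
The layer-by-timestep pruning ratios described in the Figure 4 caption can be turned into integer token budgets ahead of time, so that every (layer, timestep) pair has a known, static shape. The sketch below assumes a pruning ratio means the fraction of tokens removed. The helper name `token_budgets` and the ratio values are hypothetical and are not taken from the paper.

```python
def token_budgets(prune_ratios, num_tokens):
    """Map a [layer][timestep] table of pruning ratios to integer keep-counts k."""
    budgets = []
    for layer_ratios in prune_ratios:
        row = []
        for r in layer_ratios:
            if not 0.0 <= r < 1.0:
                raise ValueError(f"pruning ratio must be in [0, 1), got {r}")
            # Keep at least one token; round to the nearest integer count.
            row.append(max(1, round((1.0 - r) * num_tokens)))
        budgets.append(row)
    return budgets


# Made-up ratios for 2 layers x 2 timesteps, with 1024 tokens per image.
ratios = [[0.0, 0.25], [0.5, 0.75]]
budgets = token_budgets(ratios, num_tokens=1024)
```

Because the budgets are plain integers computed before inference, a compiler can specialize each configuration without handling ragged tensors.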