Shiva-DiT: Residual-Based Differentiable Top-$k$ Selection for Efficient Diffusion Transformers
Jiaji Zhang, Hailiang Zhao, Guoxuan Zhu, Ruichao Sun, Jiaju Wu, Xinkui Zhao, Hanlin Tang, Weiyi Lu, Kan Liu, Tao Lan, Lin Qu, Shuiguang Deng
TL;DR
Shiva-DiT reduces the prohibitive compute cost of Diffusion Transformers by enforcing a fixed token budget through a differentiable top-$k$ mechanism that preserves end-to-end learning. The approach combines Residual-Based Differentiable Sorting, a Context-Aware Token Importance Router, and an Adaptive Ratio Policy to allocate computation dynamically while keeping hardware-friendly static shapes. Key contributions are a gradient estimator that propagates through the hard token selection to learn $k$, a lightweight router that captures spatiotemporal token importance, and a stabilized budget constraint with stratified sampling. Together they establish a new Pareto frontier (e.g., a ~$1.54\times$ wall-clock speedup on SD3.5) with preserved fidelity. The results show practical speedups for large diffusion backbones on high-resolution inputs, making real-time or resource-constrained deployment of diffusion-based generative models more feasible.
Abstract
Diffusion Transformers (DiTs) incur prohibitive computational costs due to the quadratic scaling of self-attention. Existing pruning methods fail to simultaneously satisfy differentiability, efficiency, and the strict static token budgets required to avoid hardware overhead. To address this, we propose Shiva-DiT, which reconciles these conflicting requirements via Residual-Based Differentiable Top-$k$ Selection. By leveraging a residual-aware straight-through estimator, our method enforces deterministic token counts for static compilation while preserving end-to-end learnability through residual gradient estimation. Furthermore, we introduce a Context-Aware Router and an Adaptive Ratio Policy to autonomously learn an adaptive pruning schedule. Experiments on mainstream models, including SD3.5, demonstrate that Shiva-DiT establishes a new Pareto frontier, achieving a 1.54$\times$ wall-clock speedup with superior fidelity compared to existing baselines while eliminating ragged-tensor overheads.
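The static-budget argument can be illustrated with a minimal sketch. It is not the Shiva-DiT implementation: it shows only plain, non-differentiable top-$k$ selection next to a score threshold, and it omits the residual-based gradient estimator, the router, and the ratio policy entirely. The function names (`select_top_k`, `select_by_threshold`, `attention_cost_ratio`) and the score values are hypothetical, chosen for illustration.

```python
from typing import List, Sequence


def select_top_k(scores: Sequence[float], k: int) -> List[int]:
    """Return the indices of the k highest-scoring tokens, in original order."""
    if not 0 < k <= len(scores):
        raise ValueError("k must be in [1, len(scores)]")
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


def select_by_threshold(scores: Sequence[float], tau: float) -> List[int]:
    """Return the indices of tokens scoring above tau; the count varies per input."""
    return [i for i, s in enumerate(scores) if s > tau]


def attention_cost_ratio(k: int, n: int) -> float:
    """Cost of the pairwise attention scores for k kept tokens, relative to n tokens."""
    return (k / n) ** 2


# Hypothetical importance scores for two samples of 8 tokens each.
batch = [
    [0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4],
    [0.25, 0.1, 0.9, 0.1, 0.2, 0.1, 0.3, 0.15],
]
k = 4

# Fixed budget: every sample keeps exactly k tokens, so the batch stays rectangular.
top_k_kept = [select_top_k(scores, k) for scores in batch]

# Threshold: the number of kept tokens depends on the input, giving a ragged batch.
threshold_kept = [select_by_threshold(scores, 0.5) for scores in batch]

print([len(idx) for idx in top_k_kept])      # same length for every sample
print([len(idx) for idx in threshold_kept])  # lengths differ across samples
print(attention_cost_ratio(k, len(batch[0])))
```

With a fixed `k`, every sample yields an index list of the same length, which is what allows a compiler to assume static tensor shapes. The thresholded variant keeps a different number of tokens per sample and would need padding or ragged-tensor handling. The cost ratio covers only the quadratic attention-score term, so it should not be read as an end-to-end speedup estimate.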
