The Key to State Reduction in Linear Attention: A Rank-based Perspective
Philipp Nazari, T. Konstantin Rusch
TL;DR
This work analyzes why linear attention's associative memory often operates at low effective rank and how this degrades retrieval under noise. It develops a rank-centric theory linking effective rank, rank utilization, and retrieval error, and then introduces a hardware-aware, axis-aligned pruning framework (including DRRQR, a structured pruning method based on a rank-revealing QR decomposition) that reduces state size while preserving compatibility with existing depthwise convolutions and CUDA kernels. Empirically, pruning can remove about 50% of key and query channels with only a modest increase in perplexity and notable throughput gains, though recall-heavy tasks may still suffer without hybridizing with softmax attention. The results provide a principled path toward faster, more memory-efficient linear-attention models and offer design guidance for hybrid architectures that balance efficiency with retrieval performance.
Abstract
Linear attention offers a computationally efficient yet expressive alternative to softmax attention. However, recent empirical results indicate that the state of trained linear attention models often exhibits a low-rank structure, suggesting that these models underexploit their capacity in practice. To illuminate this phenomenon, we provide a theoretical analysis of the role of rank in linear attention, revealing that low effective rank can affect retrieval error by amplifying query noise. In addition to these theoretical insights, we conjecture that the low-rank states can be substantially reduced post-training with only minimal performance degradation, yielding faster and more memory-efficient models. To this end, we propose a novel hardware-aware approach that structurally prunes key and query matrices, reducing the state size while retaining compatibility with existing CUDA kernels. We adapt several existing pruning strategies to fit our framework and, building on our theoretical analysis, propose a novel structured pruning method based on a rank-revealing QR decomposition. Our empirical results, evaluated across models of varying sizes and on various downstream tasks, demonstrate the effectiveness of our state reduction framework. We highlight that our framework enables the removal of 50% of the query and key channels at only a marginal increase in perplexity. The code for this project can be found at https://github.com/camail-official/LinearAttentionPruning.
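To make the idea of axis-aligned state reduction concrete, the following is a minimal NumPy sketch, not the authors' implementation. It assumes an unnormalized linear-attention state S = sum_t k_t v_t^T, measures its rank with the entropy-based effective rank (one common definition, which may differ from the one used in the paper), and keeps half of the key and query channels using a greedy column-pivoted QR selection on the key activations. The function names `effective_rank` and `select_channels`, the toy dimensions, and the choice to run the selection on raw key activations are all illustrative assumptions.

```python
import numpy as np


def effective_rank(matrix):
    """Entropy-based effective rank: exp of the Shannon entropy of the
    normalized singular values (an assumed definition for illustration)."""
    s = np.linalg.svd(matrix, compute_uv=False)
    p = s / s.sum()
    p = p[p > 0]  # drop exact zeros so log() is defined
    return float(np.exp(-np.sum(p * np.log(p))))


def select_channels(features, num_keep):
    """Greedy column-pivoted QR: repeatedly keep the column with the largest
    residual norm, then project that direction out of all columns."""
    residual = np.array(features, dtype=float)
    available = np.ones(residual.shape[1], dtype=bool)
    kept = []
    for _ in range(num_keep):
        norms = np.where(available, np.linalg.norm(residual, axis=0), -1.0)
        j = int(np.argmax(norms))
        kept.append(j)
        available[j] = False
        if norms[j] > 1e-12:  # skip deflation for numerically zero columns
            direction = residual[:, j] / norms[j]
            residual -= np.outer(direction, direction @ residual)
    return sorted(kept)


# Toy data (hypothetical sizes): keys are confined to an r-dimensional subspace.
rng = np.random.default_rng(0)
T, d_k, d_v, r = 64, 8, 8, 4
K = rng.standard_normal((T, r)) @ rng.standard_normal((r, d_k))
Q = rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))

S = K.T @ V                       # full state, shape (d_k, d_v)
kept = select_channels(K, d_k // 2)
S_pruned = K[:, kept].T @ V       # reduced state, shape (d_k // 2, d_v)
out_pruned = Q[:, kept] @ S_pruned  # read the final state with every query (no causal mask)

print("effective rank of full state:", effective_rank(S))
print("kept key/query channels:", kept)
print("state shape:", S.shape, "->", S_pruned.shape)
```

Because the same channel indices are dropped from both keys and queries, the reduced state remains an ordinary dense matrix of smaller size, which is the property that lets axis-aligned pruning reuse existing kernels. The sketch does not attempt to preserve outputs exactly; how much quality is lost depends on the trained model and the selection criterion.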
