HSR-Enhanced Sparse Attention Acceleration
Bo Chen, Yingyu Liang, Zhizhou Sha, Zhenmei Shi, Zhao Song
TL;DR
This work targets the attention bottleneck in long-context LLMs by exploiting sparsity with Half-Space Reporting (HSR). It uses HSR to quickly report the index set of massively activated entries (for Softmax attention) or nonzero entries (for ReLU attention) and computes attention only on those entries. This reduces the running time to $O(mn^{4/5})$ for generation decoding and $O(mn^{1-1/\lfloor d/2\rfloor}+mn^{4/5})$ for prompt prefilling, with provably negligible error for Softmax attention under mild assumptions. The framework covers two concrete pipelines, generation decoding with fixed keys and prompt prefilling with dynamic keys, each with a runtime guarantee, and it is supported by sparsity and error analyses together with empirical validation on prominent models. By combining a geometric data structure with attention sparsity, the work connects theory to potential latency and throughput gains for long-context transformers.
Abstract
Large Language Models (LLMs) have demonstrated remarkable capabilities across various applications, but their performance on long-context tasks is often limited by the computational complexity of attention mechanisms. We introduce a novel approach to accelerate attention computation in LLMs, particularly for long-context scenarios. We leverage the inherent sparsity within attention mechanisms, both in conventional Softmax attention and ReLU attention (with $\mathsf{ReLU}^\alpha$ activation, $\alpha \in \mathbb{N}_+$), to significantly reduce the running time complexity. Our method employs a Half-Space Reporting (HSR) data structure to identify non-zero or ``massively activated'' entries in the attention matrix. We present theoretical analyses for two key scenarios: generation decoding and prompt prefilling. For generation decoding, our approach achieves a running time of $O(mn^{4/5})$, significantly faster than the $O(mn)$ of the naive approach, where $n$ is the context length, $m$ is the query length, and $d$ is the hidden dimension. We can also reduce the running time for prompt prefilling from $O(mn)$ to $O(mn^{1 - 1 / \lfloor d/2\rfloor} + mn^{4/5})$. Our method introduces only provably negligible error for Softmax attention. This work represents a significant step towards enabling efficient long-context processing in LLMs.
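The idea of computing attention only on reported entries can be illustrated with a minimal sketch for the ReLU case. Everything below is an assumption made for illustration, not the authors' implementation: the function names are hypothetical, the ReLU attention is taken to be a row-normalized $\mathsf{ReLU}^\alpha(\langle q, k\rangle - b)$ with a threshold $b$ (any $1/\sqrt{d}$ scaling is omitted), and `half_space_report` is a brute-force $O(nd)$ stand-in for a real HSR data structure, so it shows which indices would be reported but does not achieve the sublinear query time.

```python
import numpy as np

def half_space_report(K, q, b):
    """Brute-force stand-in for an HSR query (assumption, not a real HSR
    structure): return indices i with <K[i], q> > b."""
    return np.nonzero(K @ q > b)[0]

def relu_attention_dense(Q, K, V, b, alpha=1):
    """Reference: row-normalized ReLU^alpha attention over all n keys."""
    A = np.maximum(Q @ K.T - b, 0.0) ** alpha
    D = A.sum(axis=1, keepdims=True)
    D[D == 0] = 1.0  # rows with no active key produce a zero output
    return (A / D) @ V

def relu_attention_sparse(Q, K, V, b, alpha=1):
    """Same output, but each query only touches the keys reported for it."""
    out = np.zeros((Q.shape[0], V.shape[1]))
    for i, q in enumerate(Q):
        idx = half_space_report(K, q, b)  # the nonzero entries of row i
        if idx.size == 0:
            continue
        w = (K[idx] @ q - b) ** alpha     # positive by construction
        out[i] = (w / w.sum()) @ V[idx]
    return out

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    m, n, d = 4, 1024, 8
    Q = rng.standard_normal((m, d))
    K = rng.standard_normal((n, d))
    V = rng.standard_normal((n, d))
    b = 4.0  # hypothetical threshold; larger b gives sparser rows
    active = [half_space_report(K, q, b).size for q in Q]
    same = np.allclose(relu_attention_dense(Q, K, V, b, alpha=2),
                       relu_attention_sparse(Q, K, V, b, alpha=2))
    print("active keys per query:", active, "of", n)
    print("sparse matches dense:", same)
```

Because entries with $\langle q, k\rangle \le b$ contribute exactly zero under ReLU, the sparse computation in this sketch reproduces the dense result; for Softmax attention the skipped entries are small rather than zero, which is why the abstract speaks of a negligible error in that case.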
