How Sparse Attention Approximates Exact Attention? Your Attention is Naturally $n^C$-Sparse
Yichuan Deng, Zhao Song, Jing Xiong, Chiwun Yang
TL;DR
This work provides a theoretical framework showing that standard attention is naturally sparse, with the effective number of large entries per row growing only as a sublinear function of input length. It derives concentration bounds for the underlying quantities, introduces the notion of attention collapse, and proves that a sparsity window of size $k=\Omega(n^{C})$ (for any constant $C\in(0,1)$) suffices to approximate exact attention with vanishing error, while $k=o(\log n)$ is insufficient. It then advocates a dynamic Top-$k$ strategy, $k=\alpha n^{C}$, over fixed windows to achieve better accuracy-efficiency trade-offs and validates these insights with empirical simulations and long-context benchmarks. The findings provide concrete guidance for designing sub-quadratic attention mechanisms and understanding when sparse attention can be reliable in practice. Overall, the results offer theoretical justification for adaptive sparsity in attention and suggest robust directions for sparse transformer architectures in long-context settings.
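The dynamic Top-$k$ rule can be sketched in a few lines of Python. This is a minimal, hedged illustration and not the authors' implementation: the helper names `dynamic_window`, `topk_attention`, and `exact_attention` are hypothetical, the ceiling-and-clip rounding of $k = \alpha n^{C}$ is an assumption, and the logits are synthetic Gaussian samples rather than scores from a real model. For one query, the sketch keeps the $k$ largest logits, renormalizes the softmax over them, and averages the corresponding value vectors.

```python
import math
import random


def dynamic_window(n: int, alpha: float, C: float) -> int:
    """Hypothetical helper: window size k = ceil(alpha * n**C), clipped to [1, n]."""
    if alpha <= 0 or not 0.0 < C < 1.0:
        raise ValueError("expected alpha > 0 and 0 < C < 1")
    return max(1, min(n, math.ceil(alpha * n ** C)))


def topk_attention(scores, values, k):
    """Hypothetical helper: one query's attention output using only its k largest logits.

    scores: n (already scaled) attention logits for this query
    values: n value vectors, each a list of d floats
    """
    n = len(scores)
    keep = sorted(range(n), key=lambda j: scores[j], reverse=True)[:k]
    m = max(scores[j] for j in keep)                  # shift by the max for numerical stability
    w = {j: math.exp(scores[j] - m) for j in keep}    # unnormalized softmax weights of kept entries
    z = sum(w.values())                               # renormalize over the kept entries only
    d = len(values[0])
    return [sum(w[j] * values[j][t] for j in keep) / z for t in range(d)]


def exact_attention(scores, values):
    """Exact softmax attention is the special case k = n."""
    return topk_attention(scores, values, len(scores))


if __name__ == "__main__":
    rng = random.Random(0)
    n, d = 1000, 8
    scores = [rng.gauss(0.0, 2.0) for _ in range(n)]                  # synthetic logits
    values = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    k = dynamic_window(n, alpha=1.0, C=0.5)                           # ceil(sqrt(1000)) = 32
    approx = topk_attention(scores, values, k)
    exact = exact_attention(scores, values)
    print(k, max(abs(a - b) for a, b in zip(approx, exact)))
```

With $k = n$ the sparse path reduces to exact attention, which is why `exact_attention` simply calls `topk_attention` on the full row. The full sort costs $O(n \log n)$ per row; a partial selection such as `heapq.nlargest` would avoid it, but that detail is beside the point of the sketch.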
Abstract
Sparse attention approximates standard attention with sub-quadratic complexity by selectively ignoring the smaller entries of the attention matrix when computing the softmax. Variants of this technique, such as KV-cache pruning, sparsity-based fast attention, and the Sparse Transformer, are widely used to deploy Large Language Models (LLMs) efficiently. Despite this widespread use, the conditions under which sparse attention performs on par with standard attention remain poorly understood theoretically. This work aims to bridge this gap by examining the inherent sparsity of standard attention. Our theoretical framework yields three key insights:
• Attention is $n^{C}$-sparse: keeping only the largest $\Omega(n^{C})$ of the $n$ entries is sufficient for sparse attention to approximate the exact attention matrix with an error that decreases as $n$ grows. Here, $n$ is the input length and $C \in (0, 1)$ is a constant.
• Stable $o(\log n)$-sparse attention, which approximates attention using asymptotically fewer than $\log n$ entries, may not be feasible, since the error remains bounded below by a constant, i.e., it is $\Omega(1)$.
• An adaptive window size $k = \alpha \cdot n^{C}$ ($\alpha > 0$) for efficient attention methods, rather than a fixed one, is guaranteed to be more accurate and more efficient for inference across varying context lengths.
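For a single attention row, keeping the top-$k$ softmax probabilities and renormalizing gives an $\ell_1$ error equal to twice the dropped probability mass: if $S$ is the kept mass, the kept entries each grow by a factor $1/S$, contributing $1 - S$ in total, and the dropped entries contribute another $1 - S$. The sketch below uses this identity to compare a fixed window with an adaptive window $k = \lceil \alpha n^{C} \rceil$ as $n$ grows. It is a hedged illustration only: the helpers `softmax`, `topk_row_error`, and `compare` are hypothetical, the Gaussian logits and the choices $\alpha = 1$, $C = 0.5$, fixed $k = 16$ are arbitrary assumptions, and it does not reproduce the paper's experiments or claim anything about real model score distributions.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax of a list of logits."""
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    z = sum(e)
    return [v / z for v in e]


def topk_row_error(scores, k):
    """Return (l1 error of renormalized top-k row vs exact row, 2 * dropped mass)."""
    p = softmax(scores)
    order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    keep = set(order[:k])
    kept_mass = sum(p[j] for j in keep)
    q = [p[j] / kept_mass if j in keep else 0.0 for j in range(len(p))]  # sparse, renormalized row
    err = sum(abs(a - b) for a, b in zip(p, q))
    return err, 2.0 * (1.0 - kept_mass)


def compare(lengths, fixed_k=16, alpha=1.0, C=0.5, sigma=2.0, seed=0):
    """Hypothetical experiment: fixed vs adaptive window on synthetic Gaussian logits."""
    rng = random.Random(seed)
    rows = []
    for n in lengths:
        scores = [rng.gauss(0.0, sigma) for _ in range(n)]
        k_adapt = max(1, min(n, math.ceil(alpha * n ** C)))
        err_fixed, _ = topk_row_error(scores, min(fixed_k, n))
        err_adapt, _ = topk_row_error(scores, k_adapt)
        rows.append((n, k_adapt, err_fixed, err_adapt))
    return rows


if __name__ == "__main__":
    for n, k, ef, ea in compare([256, 1024, 4096, 16384]):
        print(f"n={n:6d}  fixed k=16 err={ef:.3f}  adaptive k={k:4d} err={ea:.3f}")
```

Because the kept mass can only grow as $k$ increases, the error of the top-$k$ row is non-increasing in $k$ and reaches zero at $k = n$; whether a given window is "large enough" then depends on how concentrated the softmax mass is, which is the question the theory above addresses.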
