Scaling Attention via Feature Sparsity

Yan Xie, Tiansheng Wen, Tangda Huang, Bo Chen, Chenyu You, Stefanie Jegelka, Yifei Wang

Abstract

Scaling Transformers to ultra-long contexts is bottlenecked by the $O(n^2 d)$ cost of self-attention. Existing methods reduce this cost along the sequence axis through local windows, kernel approximations, or token-level sparsity, but these approaches consistently degrade accuracy. In this paper, we instead explore an orthogonal axis: feature sparsity. We propose Sparse Feature Attention (SFA), where queries and keys are represented as $k$-sparse codes that preserve high-dimensional expressivity while reducing the cost of attention from $\Theta(n^2 d)$ to $\Theta(n^2 k^2/d)$. To make this efficient at scale, we introduce FlashSFA, an IO-aware kernel that extends FlashAttention to operate directly on sparse overlaps without materializing dense score matrices. Across GPT-2 and Qwen3 pretraining, SFA matches dense baselines while improving speed by up to $2.5\times$ and reducing FLOPs and KV-cache by nearly 50%. On synthetic and downstream benchmarks, SFA preserves retrieval accuracy and robustness at long contexts, outperforming short-embedding baselines that collapse feature diversity. These results establish feature-level sparsity as a complementary and underexplored axis for efficient attention, enabling Transformers to scale to orders-of-magnitude longer contexts with minimal quality loss. Code is available at https://github.com/YannX1e/Sparse-Feature-Attention.
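
One way to see where a $k^2/d$ factor can come from is a short counting argument. The sketch below is not taken from the paper: it assumes, purely for illustration, that the nonzero positions (supports $S_u$, $S_v$) of a $k$-sparse query code $u$ and a $k$-sparse key code $v$ are drawn independently and uniformly at random from the $d$ feature channels. Under that assumption, the expected number of shared channels per query-key pair, and hence the expected number of multiply-adds per score, is $k^2/d$.

```latex
% Assumption (illustrative, not from the paper): S_u, S_v \subset \{1,\dots,d\},
% |S_u| = |S_v| = k, drawn independently and uniformly at random.
\begin{align*}
\mathbb{E}\,|S_u \cap S_v|
  &= \sum_{j=1}^{d} \Pr[j \in S_u]\,\Pr[j \in S_v]
   = d \cdot \frac{k}{d} \cdot \frac{k}{d}
   = \frac{k^2}{d}, \\
\text{expected cost of all } n^2 \text{ scores}
  &= n^2 \cdot \Theta\!\left(\frac{k^2}{d}\right)
   = \Theta\!\left(\frac{n^2 k^2}{d}\right)
   \quad \text{versus} \quad \Theta(n^2 d) \text{ for dense attention.}
\end{align*}
```

As a hypothetical example, $d = 128$ and $k = 16$ give $k^2/d = 2$ expected overlapping channels per pair, compared with $128$ multiply-adds for a dense dot product. Learned codes need not have uniformly random supports, so this should be read as intuition for the scaling, not as a guarantee.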

Paper Structure (47 sections, 12 equations, 11 figures, 12 tables, 1 algorithm)

Figures (11)

  • Figure 1: Overview of our proposed method. (a) Trade-off between performance and speed. Compared to directly reducing dimensionality with short embeddings, our method achieves a more favorable balance, delivering a 259% speedup over the original dimensionality while improving performance by 21.4% relative to the short-embedding baseline. (b) Computational and memory efficiency comparison. Our method reduces KV-cache memory usage by 41% and FLOPs by 49%.
  • Figure 2: Three paradigms of attention. Left: Standard attention computes all $N \times N$ query–key interactions in the full feature dimension $d$. Middle: Sparse attention reduces cost by selecting, for each query $i$, a small subset of keys $\Omega_i$ and masking the remaining logits before softmax, but each retained interaction still spans all $d$ features. Right: Sparse Feature Attention (ours) keeps all tokens but sparsifies along the feature axis by selecting the top-$k$ channels in $Q$ and $K$ ($\tilde{Q}=\mathrm{Topk}_k(Q), \tilde{K}=\mathrm{Topk}_k(K)$). Attention is then computed only over overlapping selected features with sparse matrix multiplication. This shifts sparsity from the token axis ($N \times N$) to the feature axis, achieving efficiency while preserving token coverage.
  • Figure 3: Latency vs. feature sparsity. Latency comparison of dense attention and SFA (ours) at different module levels of the Transformer under a 16k context length. Higher sparsity brings a substantial decrease in latency.
  • Figure 4: Latency vs. feature sparsity under various configurations. Latency comparison of dense attention and SFA (ours) at different head dimensions and context lengths. Notably, the latency of SFA can be much lower than that of dense attention at a high per-head dimension and long context (see, e.g., the subfigure labeled `fig:sparsity_scaling_latency_256_65k`).
  • Figure 5: Scaling dense attention and SFA with context length. SFA consistently reduces both the computation cost and the KV-cache size by a constant factor of at least $2$.
  • ...and 6 more figures
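
The mechanism described in the Figure 2 caption (top-$k$ channel selection on $Q$ and $K$, then scores computed only over overlapping selected features) can be sketched in a few lines of NumPy. This is a minimal reference sketch, not the authors' FlashSFA kernel: the function names `topk_sparsify`, `overlap_scores`, and `sfa_attention` are hypothetical, selecting channels by absolute value is an assumption, and the $1/\sqrt{d}$ scaling is the standard attention convention and not something stated in the text above. The explicit Python loops are for readability only and give no speedup.

```python
import numpy as np

def topk_sparsify(x, k):
    """Keep the k largest-magnitude channels in each row, zero the rest.
    Selecting by |x| is an assumption made for this sketch."""
    idx = np.argpartition(-np.abs(x), k - 1, axis=-1)[:, :k]
    out = np.zeros_like(x)
    np.put_along_axis(out, idx, np.take_along_axis(x, idx, axis=-1), axis=-1)
    return out, idx

def overlap_scores(q_sp, q_idx, k_sp, k_idx):
    """Score each (query, key) pair using only channels selected by both."""
    n_q, n_k = q_sp.shape[0], k_sp.shape[0]
    scores = np.zeros((n_q, n_k))
    for i in range(n_q):
        q_channels = set(q_idx[i].tolist())
        for j in range(n_k):
            shared = [c for c in k_idx[j].tolist() if c in q_channels]
            scores[i, j] = sum(q_sp[i, c] * k_sp[j, c] for c in shared)
    return scores

def sfa_attention(Q, K, V, k):
    """Softmax attention over sparse-feature scores (reference version)."""
    Qs, qi = topk_sparsify(Q, k)
    Ks, ki = topk_sparsify(K, k)
    # 1/sqrt(d) is the usual attention scaling, assumed here.
    scores = overlap_scores(Qs, qi, Ks, ki) / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

# Small demo with illustrative sizes (not from the paper).
rng = np.random.default_rng(0)
Q = rng.standard_normal((5, 32))
K = rng.standard_normal((7, 32))
V = rng.standard_normal((7, 8))
Qs, qi = topk_sparsify(Q, 4)
Ks, ki = topk_sparsify(K, 4)
S = overlap_scores(Qs, qi, Ks, ki)
out, weights = sfa_attention(Q, K, V, 4)
```

The overlap-only scores equal the dense product of the sparsified matrices, `Qs @ Ks.T`, because any channel missing from either support contributes zero. All tokens remain in the softmax, so the saving comes entirely from the feature axis, which is the point the caption makes.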