SSA: Sparse Sparse Attention by Aligning Full and Sparse Attention Outputs in Feature Space

Zhenyi Shen, Junru Lu, Lin Gui, Jiazheng Li, Yulan He, Di Yin, Xing Sun

TL;DR

The paper tackles the quadratic complexity of full self-attention in long-context LLMs by introducing Sparse Sparse Attention (SSA), a unified training framework that jointly trains sparse and full attention while enforcing bidirectional alignment at every layer. SSA preserves gradient flow to all tokens and explicitly aligns sparse-attention outputs with their full-attention counterparts, leading to higher inherent attention sparsity and strong performance under both sparse and full inference. Across language modeling and commonsense reasoning benchmarks, SSA achieves state-of-the-art results, improves long-context extrapolation, and adapts smoothly to varying sparsity budgets. The approach also mitigates attention sinks and shows practical value for scalable, compute-flexible LLMs. The training objective $L = \mathbb{E}_{\text{mode} \sim \{\text{full}, \text{sparse}\}} [L_{\text{mode}}] + \alpha L_{\text{alignment}}$, where $L_{\text{alignment}} = L_{\text{sparsity}} + L_{\text{commitment}}$, captures the balance between accuracy and efficiency: the first term is the expected loss over the randomly sampled full or sparse mode, and the second, weighted by $\alpha$, aligns the two attention outputs.
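The objective above can be read as a per-iteration recipe. The minimal Python sketch that follows assumes, as the Figure 2 caption states, that the full or sparse stream is chosen with equal probability at each iteration. The function names `ssa_iteration_loss` and `ssa_expected_loss` are hypothetical. Because this section does not say how $L_{\text{sparsity}}$ and $L_{\text{commitment}}$ are computed, they are passed in as precomputed scalars, and the per-stream alignment directions described for Figure 2 are not modeled.

```python
import random


def ssa_iteration_loss(loss_full: float, loss_sparse: float,
                       l_sparsity: float, l_commitment: float,
                       alpha: float, rng: random.Random) -> float:
    """Loss for one iteration: sample a stream, then add the weighted alignment term."""
    # Assumption (from the Figure 2 caption): FA and SA streams are equally likely.
    mode = rng.choice(("full", "sparse"))
    l_mode = loss_full if mode == "full" else loss_sparse
    # L_alignment = L_sparsity + L_commitment (both treated as given scalars here).
    return l_mode + alpha * (l_sparsity + l_commitment)


def ssa_expected_loss(loss_full: float, loss_sparse: float,
                      l_sparsity: float, l_commitment: float,
                      alpha: float) -> float:
    """Closed form of E_mode[L_mode] + alpha * L_alignment with mode ~ Uniform{full, sparse}."""
    return 0.5 * (loss_full + loss_sparse) + alpha * (l_sparsity + l_commitment)


if __name__ == "__main__":
    rng = random.Random(0)
    args = dict(loss_full=2.0, loss_sparse=3.0,
                l_sparsity=0.1, l_commitment=0.2, alpha=0.5)
    n = 100_000
    # Monte Carlo average of the sampled per-iteration losses.
    mc = sum(ssa_iteration_loss(**args, rng=rng) for _ in range(n)) / n
    print("closed form:", ssa_expected_loss(**args), "sampled mean:", mc)
```

Because the mode is drawn uniformly, each sampled loss is an unbiased estimate of $\tfrac{1}{2}(L_{\text{full}} + L_{\text{sparse}}) + \alpha L_{\text{alignment}}$, so the sampled mean approaches the closed form as the number of iterations grows.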

Abstract

The quadratic complexity of full attention limits efficient long-context processing in large language models (LLMs). Sparse attention mitigates this cost by restricting each query to attend to a subset of previous tokens; however, training-free approaches often lead to severe performance degradation. Native sparse-attention methods (e.g., NSA, MoBA) alleviate this issue, yet exhibit a critical paradox: they produce lower attention sparsity than full-attention models, despite aiming to approximate full attention, which may constrain their effectiveness. We attribute this paradox to gradient update deficiency: low-ranked key-value pairs excluded during sparse training receive neither forward contribution nor backward gradients, and thus never learn proper suppression. To overcome this limitation, we propose SSA (Sparse Sparse Attention), a unified training framework that considers both sparse and full attention and enforces bidirectional alignment at every layer. This design preserves gradient flow to all tokens while explicitly encouraging sparse-attention outputs to align with their full-attention counterparts, thereby promoting stronger sparsity. As a result, SSA achieves state-of-the-art performance under both sparse and full attention inference across multiple commonsense benchmarks. Furthermore, SSA enables models to adapt smoothly to varying sparsity budgets; performance improves consistently as more tokens are allowed to attend, supporting flexible compute-performance trade-offs at inference time. Finally, we show that native sparse-attention training surprisingly improves long-context extrapolation by mitigating the over-allocation of attention values in sink areas, with SSA demonstrating the strongest extrapolation capability.
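To make the gradient update deficiency concrete, here is a minimal numpy sketch of top-k sparse causal attention. It is a simplification under stated assumptions, not NSA, MoBA, or the SSA kernel: it selects individual tokens by raw dot-product score, whereas NSA and MoBA select blocks of keys, and `topk_causal_attention` is a hypothetical helper. A key that a query drops receives a weight of exactly zero, so changing its value vector leaves that query's output unchanged, which is the forward-pass view of receiving no gradient. Setting `budget` to the sequence length recovers ordinary full causal attention.

```python
import numpy as np


def topk_causal_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray,
                          budget: int) -> tuple[np.ndarray, np.ndarray]:
    """Causal attention where query i sees only its `budget` top-scoring keys j <= i."""
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)                    # (n, n) scaled dot products
    keep = np.zeros((n, n), dtype=bool)
    for i in range(n):
        visible = scores[i, : i + 1]                 # causal: keys 0..i only
        m = min(budget, i + 1)
        keep[i, np.argsort(visible)[-m:]] = True     # indices of the m largest scores
    masked = np.where(keep, scores, -np.inf)         # dropped keys get -inf logits
    masked = masked - masked.max(axis=1, keepdims=True)
    weights = np.exp(masked)                         # exp(-inf) == 0: no forward contribution
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v, weights


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d = 8, 4
    q, k, v = (rng.standard_normal((n, d)) for _ in range(3))

    out, w = topk_causal_attention(q, k, v, budget=2)
    # Pick a key the last query dropped and change its value vector drastically.
    dropped = int(np.flatnonzero(w[-1] == 0)[0])
    v_perturbed = v.copy()
    v_perturbed[dropped] += 10.0
    out_p, _ = topk_causal_attention(q, k, v_perturbed, budget=2)
    # The last query's output does not move: zero sensitivity to the dropped value.
    print(np.allclose(out_p[-1], out[-1]))  # True

    # Gap to full causal attention for growing budgets; budget == n gives zero gap.
    full_out, _ = topk_causal_attention(q, k, v, budget=n)
    for b in (1, 2, 4, n):
        gap = np.abs(topk_causal_attention(q, k, v, budget=b)[0] - full_out).max()
        print(b, gap)
```

In an autodiff framework the same zero weight also zeroes the gradient that reaches the dropped key's logit, so purely sparse training never pushes those scores down; this is the gap the abstract attributes to native sparse attention and that SSA's full-attention stream is meant to close.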

Paper Structure

This paper contains 23 sections, 10 equations, 12 figures, 8 tables, and 1 algorithm.

Figures (12)

  • Figure 1: Preliminary results using 300M-parameter models trained on 50B tokens. The FA model (full attention training) exhibits higher attention sparsity and lower attention entropy than the SA model (sparse attention training), and each performs best under its native inference mode. In contrast, SSA attains the highest attention sparsity and achieves the strongest performance under both full- and sparse-attention inference. (a) Perplexity on WikiText @8k; (b) Average scores on commonsense reasoning benchmarks; (c) Attention entropy; (d) Attention sparsity.
  • Figure 2: Illustration of the SSA training framework. At each iteration, the model has an equal probability of following either the Sparse Attention (SA) stream or the Full Attention (FA) stream. In the SA stream, the model learns sparse attention while aligning its output with a full-attention counterpart computed on the fly. Conversely, in the FA stream, the model learns full attention constrained by alignment with the corresponding sparse-attention output. For clarity, skip connections, normalization, and dropout layers are omitted from the figure.
  • Figure 3: Performance versus receptive-field size. SSA and FullAttn extrapolate well, consistently improving as more tokens become visible, whereas MoBA exhibits poor extrapolation.
  • Figure 4: (a) Perplexity across context lengths. “FullAttn” and “SparseAttn” in parentheses indicate full-attention and sparse-attention (receptive field = 256) inference, respectively. (b) Increasing the proportion of sparse-attention training in SSA improves length extrapolation. (c) SSA produces higher local-logit weights than MoBA and FullAttn. (d) FullAttn allocates substantial attention mass to tokens beyond 8k. Panel (a) uses 1B models; panels (b–d) use 300M models and full-attention inference.
  • Figure 5: Perplexity across different context lengths for NSA architectural ablations, where CMP denotes the compression module, SEL the selection module, and SWA the sliding-window module. NSA relies on the sliding-window component (SWA) to maintain PPL stability at long context lengths, whereas SSA achieves even better PPL stability without requiring a sliding-window mechanism.
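Figure 1 compares models by attention entropy and attention sparsity, but this section does not define either metric. The numpy sketch that follows therefore uses assumed definitions: entropy is the mean Shannon entropy (in nats) of each query's attention row, and sparsity is proxied by the mean fraction of attention mass on each query's top-k keys, so a higher value means sparser attention. Both function names are hypothetical, and the paper's exact metrics may differ.

```python
import numpy as np


def attention_entropy(weights: np.ndarray) -> float:
    """Mean Shannon entropy (nats) over rows of an attention matrix; 0 log 0 := 0."""
    w = np.asarray(weights, dtype=float)
    safe = np.where(w > 0, w, 1.0)               # avoid log(0); those terms are zeroed
    plogp = np.where(w > 0, w * np.log(safe), 0.0)
    return float(-plogp.sum(axis=1).mean())


def topk_mass(weights: np.ndarray, k: int) -> float:
    """Assumed sparsity proxy: mean attention mass on each row's k largest weights."""
    w = np.sort(np.asarray(weights, dtype=float), axis=1)[:, ::-1]  # descending per row
    return float(w[:, :k].sum(axis=1).mean())


if __name__ == "__main__":
    uniform = np.full((1, 4), 0.25)
    peaked = np.array([[0.97, 0.01, 0.01, 0.01]])
    for name, w in (("uniform", uniform), ("peaked", peaked)):
        print(name, round(attention_entropy(w), 3), round(topk_mass(w, 1), 3))
```

Under these definitions a peaked row has lower entropy and higher top-k mass than a uniform one, matching the direction in Figure 1, where the model with higher attention sparsity also shows lower attention entropy.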