SSA: Sparse Sparse Attention by Aligning Full and Sparse Attention Outputs in Feature Space
Zhenyi Shen, Junru Lu, Lin Gui, Jiazheng Li, Yulan He, Di Yin, Xing Sun
TL;DR
The paper tackles the quadratic complexity of full self-attention in long-context LLMs by introducing Sparse Sparse Attention (SSA), a unified training framework that jointly trains sparse and full attention while enforcing bidirectional alignment between their outputs at every layer. SSA preserves gradient flow to all tokens and explicitly aligns sparse-attention outputs with their full-attention counterparts, which yields higher inherent attention sparsity and strong performance under both sparse and full inference. Across language modeling and commonsense reasoning benchmarks, SSA achieves state-of-the-art results, improves long-context extrapolation, and adapts smoothly to varying sparsity budgets. The approach also mitigates the attention-sink effect, which makes it practically relevant for scalable, compute-flexible LLMs. The training objective expresses this balance between accuracy and efficiency: $L = \mathbb{E}_{\text{mode} \sim \{\text{full}, \text{sparse}\}} [L_{\text{mode}}] + \alpha L_{\text{alignment}}$, where $L_{\text{mode}}$ is the training loss under the sampled attention mode, $\alpha$ weights the alignment term, and $L_{\text{alignment}} = L_{\text{sparsity}} + L_{\text{commitment}}$.
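A minimal NumPy sketch of how this objective might be assembled in one training step. Several details are assumptions, not statements from the paper: each step draws a single mode uniformly as a one-sample estimate of the expectation; MSE stands in for whatever alignment distance the authors use; and, by analogy with VQ-VAE, $L_{\text{sparsity}}$ is taken to update only the full-attention branch (the sparse output is a fixed target) while $L_{\text{commitment}}$ updates only the sparse branch. NumPy has no autograd, so the stop-gradients are expressed by returning each term's gradient only for the branch it is assumed to train. The helper names `sample_mode`, `alignment_terms`, and `ssa_loss` are hypothetical.

```python
import random

import numpy as np


def sample_mode(rng: random.Random) -> str:
    # One-sample estimate of E_{mode ~ {full, sparse}}[L_mode]: each training
    # step draws a single attention mode. Uniform sampling is an assumption.
    return rng.choice(["full", "sparse"])


def alignment_terms(full_outs, sparse_outs):
    """Sum per-layer alignment losses and route their gradients.

    Hypothetical VQ-VAE-style reading (sg = stop-gradient):
      L_sparsity   = MSE(o_full, sg(o_sparse))  -> gradient reaches full branch only
      L_commitment = MSE(sg(o_full), o_sparse)  -> gradient reaches sparse branch only
    MSE is a stand-in for the distance actually used in the paper.
    """
    l_sparsity = l_commitment = 0.0
    grads_full, grads_sparse = [], []
    for o_full, o_sparse in zip(full_outs, sparse_outs):
        diff = o_full - o_sparse
        layer_loss = float(np.mean(diff ** 2))
        # Both terms have the same value; they differ only in gradient routing.
        l_sparsity += layer_loss
        l_commitment += layer_loss
        grads_full.append(2.0 * diff / diff.size)     # d L_sparsity / d o_full
        grads_sparse.append(-2.0 * diff / diff.size)  # d L_commitment / d o_sparse
    return l_sparsity, l_commitment, grads_full, grads_sparse


def ssa_loss(lm_loss, full_outs, sparse_outs, alpha):
    """L = L_mode + alpha * (L_sparsity + L_commitment) for the sampled mode.

    Returns the scalar loss and the alignment gradients w.r.t. each layer's
    full and sparse attention outputs (already scaled by alpha).
    """
    l_sp, l_cm, g_full, g_sparse = alignment_terms(full_outs, sparse_outs)
    total = lm_loss + alpha * (l_sp + l_cm)
    return total, [alpha * g for g in g_full], [alpha * g for g in g_sparse]


# Usage: two layers, four tokens, hidden size 8, random stand-in outputs.
np_rng = np.random.default_rng(0)
full_outs = [np_rng.normal(size=(4, 8)) for _ in range(2)]
sparse_outs = [o + 0.1 for o in full_outs]
mode = sample_mode(random.Random(0))
lm_loss = 2.0  # dummy value standing in for L_mode under the sampled mode
total, g_full, g_sparse = ssa_loss(lm_loss, full_outs, sparse_outs, alpha=0.1)
print(mode, total)
```

Under this reading the two alignment terms share a value but send gradients in opposite directions: the full-attention output is drawn toward the sparser output, and the sparse output toward the full one, which is one way to realize the bidirectional alignment described above.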
Abstract
The quadratic complexity of full attention limits efficient long-context processing in large language models (LLMs). Sparse attention mitigates this cost by restricting each query to attend to a subset of previous tokens; however, training-free approaches often lead to severe performance degradation. Native sparse-attention methods (e.g., NSA, MoBA) alleviate this issue, yet exhibit a critical paradox: they produce lower attention sparsity than full-attention models, despite aiming to approximate full attention, which may constrain their effectiveness. We attribute this paradox to a gradient update deficiency: low-ranked key-value pairs excluded during sparse training receive neither forward contribution nor backward gradients, and thus never learn proper suppression. To overcome this limitation, we propose SSA (Sparse Sparse Attention), a unified training framework that incorporates both sparse and full attention and enforces bidirectional alignment at every layer. This design preserves gradient flow to all tokens while explicitly encouraging sparse-attention outputs to align with their full-attention counterparts, thereby promoting stronger sparsity. As a result, SSA achieves state-of-the-art performance under both sparse and full attention inference across multiple commonsense benchmarks. Furthermore, SSA enables models to adapt smoothly to varying sparsity budgets; performance improves consistently as each query is allowed to attend to more tokens, supporting flexible compute-performance trade-offs at inference time. Finally, we show that native sparse-attention training surprisingly improves long-context extrapolation by mitigating the over-allocation of attention scores to sink regions, with SSA demonstrating the strongest extrapolation capability.
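To make the gradient-update-deficiency argument concrete, here is a minimal NumPy sketch of single-head causal attention restricted to the top-`budget` keys per query, with the backward pass written out by hand. Token-level top-k selection is a simplifying assumption (NSA and MoBA select blocks of keys), and the names `topk_mask` and `sparse_attention` are hypothetical. For every excluded (query, key) pair the attention weight is exactly zero, so the gradient on that score is zero too: that query's loss never pushes the low-ranked score down. With `budget` equal to the sequence length, the same code reduces to full causal attention, whose score gradients reach every visible key.

```python
import numpy as np


def topk_mask(scores, budget):
    # Keep the `budget` highest-scoring keys in each query row.
    idx = np.argsort(scores, axis=-1)[:, -budget:]
    mask = np.zeros(scores.shape, dtype=bool)
    np.put_along_axis(mask, idx, True, axis=-1)
    return mask


def sparse_attention(q, k, v, d_out, budget):
    """Causal top-`budget` attention for one head, with a hand-written backward.

    Returns the output, the selection mask, and the gradients of the loss
    w.r.t. the attention scores and the values, given upstream gradient d_out.
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    causal = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(causal, q @ k.T * scale, -np.inf)
    mask = topk_mask(scores, budget) & causal  # selected and visible pairs
    s = np.where(mask, scores, -np.inf)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)  # p[i, j] == 0 where excluded
    out = p @ v

    # Backward of out = p @ v and p = softmax(s) over the selected keys.
    d_v = p.T @ d_out
    d_p = d_out @ v.T
    d_scores = p * (d_p - np.sum(d_p * p, axis=-1, keepdims=True))
    return out, mask, d_scores, d_v


rng = np.random.default_rng(0)
n, d = 6, 4
q, k, v, d_out = (rng.normal(size=(n, d)) for _ in range(4))
out_full, _, ds_full, _ = sparse_attention(q, k, v, d_out, budget=n)
for budget in (1, 2, 4, n):
    out, mask, ds, _ = sparse_attention(q, k, v, d_out, budget)
    # Excluded pairs receive exactly zero score gradient from this query.
    print(budget, np.abs(out - out_full).max(), bool(np.all(ds[~mask] == 0)))
```

Sweeping `budget` this way also mirrors the inference-time compute knob described above: the same trained weights can be queried at any budget, and the sparse output matches full attention exactly once `budget` equals the sequence length.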
