SeerAttention-R: Sparse Attention Adaptation for Long Reasoning

Yizhao Gao, Shuming Guo, Shijie Cao, Yuqing Xia, Yu Cheng, Lei Wang, Lingxiao Ma, Yutao Sun, Tianzhu Ye, Li Dong, Hayden Kwok-Hay So, Yu Hua, Ting Cao, Fan Yang, Mao Yang

TL;DR

Long-context reasoning in autoregressive models is hampered by the cost of attending over a KV cache that grows with every decoded token. SeerAttention-R introduces a post-training, plug-in AttnGate that enables sparse decoding: it learns sparsity shared within each Grouped Query Attention (GQA) group, removes Q pooling to suit autoregressive decoding, and relies on a K-compression cache plus a fast block-sparse decoding kernel written in TileLang. The approach achieves near-lossless reasoning accuracy with a 4K token budget on challenging benchmarks such as AIME and delivers kernel speedups of up to 9x over FlashAttention-3 at 90% sparsity. Training is lightweight (only the gate is trained) and can be applied to multiple pretrained models, which eases deployment for long-sequence reasoning tasks. The work also provides thorough ablations and leaves end-to-end speedups and adaptive sparsity to future work.

Abstract

We introduce SeerAttention-R, a sparse attention framework specifically tailored for the long decoding of reasoning models. Extended from SeerAttention, SeerAttention-R retains the design of learning attention sparsity through a self-distilled gating mechanism, while removing query pooling to accommodate auto-regressive decoding. With a lightweight plug-in gate, SeerAttention-R is flexible and can be easily integrated into existing pretrained models without modifying their original parameters. We demonstrate that SeerAttention-R, trained on just 0.4B tokens, maintains near-lossless reasoning accuracy with a 4K token budget on the AIME benchmark under large sparse attention block sizes (64/128). Using TileLang, we develop a highly optimized sparse decoding kernel that achieves near-theoretical speedups of up to 9x over FlashAttention-3 on an H100 GPU at 90% sparsity. Code is available at: https://github.com/microsoft/SeerAttention.
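
To make the numbers in the abstract concrete, the following is a minimal arithmetic sketch. It assumes that "4K" means 4096 tokens, that attention cost during decoding scales linearly with the number of KV tokens read, and that the context length of 32768 is a purely illustrative value. The function names are hypothetical and do not come from the SeerAttention codebase.

```python
def blocks_in_budget(token_budget: int, block_size: int) -> int:
    """Number of whole KV blocks that fit inside a token budget."""
    return token_budget // block_size


def sparsity(token_budget: int, context_len: int) -> float:
    """Fraction of KV tokens skipped when at most `token_budget` are attended."""
    if context_len <= token_budget:
        return 0.0  # everything fits, so attention is effectively dense
    return 1.0 - token_budget / context_len


def ideal_speedup(sparsity_ratio: float) -> float:
    """Upper bound on kernel speedup, ASSUMING cost is linear in tokens read."""
    return 1.0 / (1.0 - sparsity_ratio)


if __name__ == "__main__":
    # ASSUMPTION: "4K" is taken to be 4096 tokens.
    print(blocks_in_budget(4096, 64))   # 64 blocks at block size 64
    print(blocks_in_budget(4096, 128))  # 32 blocks at block size 128
    # Illustrative context length, not a figure from the paper.
    print(sparsity(4096, 32768))
    print(ideal_speedup(0.9))
```

Under this simple cost model, 90% sparsity bounds the speedup at roughly 10x, which is one way to read the reported 9x as "near-theoretical".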

Paper Structure

This paper contains 36 sections, 1 equation, 9 figures, 2 tables.

Figures (9)

  • Figure 1: SeerAttention (Sparse Prefill) and SeerAttention-R (Sparse Decode). In SeerAttention-R, no compression or pooling along the sequence dimension is applied to the query (Q). Given that modern architectures predominantly use GQA, a linear layer projects Q from its original number of heads down to the number of KV heads, enabling shared sparsity selection within a GQA group.
  • Figure 2: Training Diagram and Training Kernel of SeerAttention-R. (a) Self-distillation training of AttnGate in SeerAttention-R. It uses 1D max-pooled attention scores from the original model as ground truth to train AttnGate. Query head reduction is not shown in the diagram for simplicity. (b) Pseudocode of the attention forward kernel for training, which directly generates both the ground truth and the attention output.
  • Figure 3: Inference Diagram of SeerAttention-R. During inference, a K Compression Cache stores the compressed key representations used by AttnGate to speed up sparse block prediction. The cache is updated only once every block-size tokens are generated (block size = 4 in the plots for illustration). Because the cache does not yet cover the most recent tokens, the last block of the sequence is always selected. $g$ is the group size of GQA.
  • Figure 4: Oracle Sparse Results of Qwen3-14B with block size 32, 64, 128.
  • Figure 5: Accuracy Results of Full Attention, SeerAttention-R, and Quest. The Quest sparse configuration is set to be the same as SeerAttention-R for fair comparison, which uses a block size of 64 and sparse attention in all layers.
  • ...and 4 more figures
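
The inference flow described in the Figure 3 caption can be sketched in plain Python as follows. This is a toy illustration under stated assumptions, not the authors' implementation: `KCompressionCache`, `gate_scores`, and `select_blocks` are hypothetical names, mean pooling stands in for the learned key compression, a dot product stands in for the AttnGate score, and a single head is modeled so the GQA query-head reduction is omitted.

```python
from typing import List


class KCompressionCache:
    """Toy cache holding one compressed key vector per *completed* block."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.pending: List[List[float]] = []     # keys of the unfinished block
        self.compressed: List[List[float]] = []  # one vector per finished block

    def append(self, key: List[float]) -> None:
        """Add one decoded token's key; compress only when a block fills up."""
        self.pending.append(key)
        if len(self.pending) == self.block_size:
            dim = len(key)
            # ASSUMPTION: mean pooling is a placeholder for the learned compression.
            pooled = [sum(k[d] for k in self.pending) / self.block_size
                      for d in range(dim)]
            self.compressed.append(pooled)
            self.pending = []


def gate_scores(query: List[float], cache: KCompressionCache) -> List[float]:
    """ASSUMPTION: a dot product stands in for the gate's block score."""
    return [sum(q * c for q, c in zip(query, ck)) for ck in cache.compressed]


def select_blocks(scores: List[float], seq_len: int,
                  block_size: int, token_budget: int) -> List[int]:
    """Pick KV block indices for one decode step under a token budget."""
    if seq_len < 1:
        raise ValueError("seq_len must be at least 1")
    if len(scores) != seq_len // block_size:
        raise ValueError("expected one score per completed block")
    num_blocks = -(-seq_len // block_size)        # ceiling division
    max_blocks = max(token_budget // block_size, 1)
    # The last block is always selected, since the cache may not cover it yet.
    selected = {num_blocks - 1}
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    for idx in ranked:
        if len(selected) >= max_blocks:
            break
        selected.add(idx)
    return sorted(selected)


if __name__ == "__main__":
    cache = KCompressionCache(block_size=4)       # block size 4, as in the figure
    for t in range(10):                           # 10 tokens: 2 full blocks + 2 pending
        cache.append([float(t)])
    scores = gate_scores([1.0], cache)
    print(select_blocks(scores, seq_len=10, block_size=4, token_budget=8))
```

With ten tokens and a block size of 4, only two blocks have compressed entries, so the partially filled third block is covered solely by the always-select rule; the remaining budget goes to the highest-scoring cached block.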