
RaaS: Reasoning-Aware Attention Sparsity for Efficient LLM Reasoning

Junhao Hu, Wenrui Huang, Weidong Wang, Zhenwen Li, Tiancheng Hu, Zhixia Liu, Xusheng Chen, Tao Xie, Yizhou Shan

TL;DR

RaaS identifies an attention pattern of milestone and phoenix tokens during the decode stage of reasoning tasks and uses it to design a sparsity-based KV cache strategy. By retaining milestone tokens through LRU-based timestamping and always preserving prefill tokens, RaaS achieves $O(L)$ time and $O(L)$ memory while maintaining accuracy comparable to the state-of-the-art Quest. A page-based variant further aligns with efficient kernels, yielding practical deployment with constant memory usage. Results across multiple math datasets and models show that RaaS offers strong accuracy and latency with a significantly reduced memory footprint, suggesting a viable path for scalable long-decode inference in reasoning-heavy applications.

Abstract

Large Language Models (LLMs) have demonstrated strong capabilities across various domains, with recent advancements in challenging reasoning tasks such as mathematics and programming. However, solving reasoning tasks often requires an LLM to generate long sequences, incurring $O(N)$ time and memory complexities per token, where $N$ is the current sequence length. To reduce these complexities, existing sparsity-based algorithms retain the Key-Value (KV) vectors (the intermediate representations) of only the most critical tokens. However, these algorithms struggle with the "impossible trinity" of accuracy, time, and memory. For example, the state-of-the-art algorithm, Quest, achieves high accuracy with $O(L)$ time but $O(N)$ memory ($L$ is the cache budget, $L \ll N$). To address the "impossible trinity", in this paper, we identify a new attention pattern during the decode stage of reasoning tasks, where milestone tokens (analogous to lemmas in mathematical proofs) emerge, are utilized, and then become unimportant afterward. Based on this pattern, we propose a new algorithm, RaaS, that identifies milestone tokens and retains their KV vectors until they are no longer needed, achieving high accuracy with $O(L)$ time and $O(L)$ memory complexities.
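
The retention idea described above (keep prefill tokens, keep milestone tokens while they are still being attended to, and evict the least recently used token once the budget $L$ is exceeded) can be illustrated with a toy cache. This is only a minimal sketch under stated assumptions, not the paper's implementation: the class `LRUKVCache`, the `threshold` rule for deciding that a token was "used", and the per-step `scores` dictionary are all hypothetical names introduced here for illustration, and real KV tensors are replaced by token positions.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class LRUKVCache:
    """Toy cache that tracks token positions instead of real KV tensors."""
    budget: int                # cache budget L: maximum number of tokens kept
    num_prefill: int           # positions [0, num_prefill) are never evicted
    threshold: float = 0.01    # assumed cutoff for "this token was attended to"
    stamps: Dict[int, int] = field(default_factory=dict)  # position -> last-used step

    def __post_init__(self) -> None:
        if self.budget <= self.num_prefill:
            raise ValueError("budget must exceed the number of prefill tokens")
        for pos in range(self.num_prefill):
            self.stamps[pos] = 0

    def step(self, new_pos: int, scores: Dict[int, float], now: int) -> List[int]:
        """Run one decode step and return the positions evicted in it."""
        # Refresh the timestamp of every cached token that received attention.
        for pos, score in scores.items():
            if pos in self.stamps and score >= self.threshold:
                self.stamps[pos] = now
        # The newly generated token enters the cache with the current timestamp.
        self.stamps[new_pos] = now
        # Evict the stalest decode tokens until the cache fits the budget again.
        evicted: List[int] = []
        while len(self.stamps) > self.budget:
            candidates = [p for p in self.stamps if p >= self.num_prefill]
            victim = min(candidates, key=lambda p: (self.stamps[p], p))
            del self.stamps[victim]
            evicted.append(victim)
        return evicted


# Usage: one prefill token, budget of four tokens.
cache = LRUKVCache(budget=4, num_prefill=1)
cache.step(new_pos=1, scores={0: 0.9}, now=1)
cache.step(new_pos=2, scores={0: 0.5, 1: 0.5}, now=2)
cache.step(new_pos=3, scores={0: 0.399, 1: 0.6, 2: 0.001}, now=3)
evicted = cache.step(new_pos=4, scores={0: 0.2, 1: 0.7, 2: 0.0, 3: 0.1}, now=4)
print(evicted)               # position 2 stopped receiving attention after step 2
print(sorted(cache.stamps))  # positions still cached
```

In this sketch the cache never holds more than `budget` entries, which is the sense in which memory stays $O(L)$; position 1 plays the role of a milestone token that survives because it keeps being attended to, while position 2 is dropped once it goes unused.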

Paper Structure

This paper contains 20 sections, 10 figures, 1 algorithm.

Figures (10)

  • Figure 1: The Cumulative Distribution Function (CDF) of sequence lengths for the Prefill (P) and Decode (D) stages for (a) five datasets from LongBench yushi2024longbench and (b) three math datasets running on the reasoning-enabled Marco-O1 model. (c) The breakdown of prefill and decode time during the inference of fixed 32k tokens using vLLM 0.6.1 with the LLaMA 3.1 8B model in FP16 precision. As the number of decode tokens increases (with the number of prefill tokens being 32k minus the decode tokens), the decode time rises significantly faster than the prefill time.
  • Figure 2: Comparison of sparsity-based algorithms. $N$ indicates the sequence length while $L$ indicates the cache budget where $L\ll N$. Asterisks on H2O's time and memory complexities indicate theoretical complexities that are not realized in practical implementations. RaaS addresses the "impossible trinity" by achieving $O(L)$ complexity for both time and memory, with accuracy comparable to Dense on reasoning tasks. Refer to Section \ref{sec-background} for detailed explanations of each algorithm's design.
  • Figure 3: A new attention pattern emerges in reasoning tasks. We manually inspect attention maps across 28 layers and 28 heads of Qwen2.5-Math-7B-Instruct qwen25mathtechnical on 100 MATH500 math500 test cases. We find that (a) $24.2\%$ of maps contain milestone tokens, (b) $1.5\%$ of maps contain phoenix tokens (with a 64-token cache budget), and (c) more than $70\%$ are "lazy" zhang2025lighttransfer maps with the StreamingLLM pattern. We make a best effort to balance the clarity and completeness of long-decode attention maps.
  • Figure 4: We input the prefill tokens, "...Convert the point $(0,3)$ to polar coordinates...", to Qwen2.5-Math-7B-Instruct and obtain the corresponding decode tokens in the figure. The red tokens represent the milestone tokens or bright columns in Figure \ref{fig-algo-waterfall} (a).
  • Figure 5: Accuracy vs. cache budget for five algorithms (legends) across three datasets (rows) and four models (columns). The y-axis shows the proportion of correctly solved problems among 200 test cases, while the x-axis represents varying cache budgets: 64, 128, 256, 512, and 1024.
  • ...and 5 more figures