AttentionPredictor: Temporal Patterns Matter for KV Cache Compression

Qingyue Yang, Jie Wang, Xing Li, Zhihai Wang, Chen Chen, Lei Chen, Xianzhi Yu, Wulong Liu, Jianye Hao, Mingxuan Yuan, Bin Li

TL;DR

AttentionPredictor addresses the KV cache memory bottleneck in long-context LLMs by learning a lightweight spatiotemporal predictor that directly forecasts next-step attention patterns to identify critical tokens. The method uses a compact CNN-based architecture shared across layers to model 2D attention dynamics, combined with a cross-token prefetching framework that hides prediction and transfer latency during decoding. It achieves substantial KV cache compression (approximately 13x) and decoding speedups (approximately 5.6x) with comparable LLM performance, and demonstrates strong generalization across datasets and tasks. This work enables more efficient long-context inference on large models and offers a practical approach to memory- and latency-constrained deployment.

Abstract

With the development of large language models (LLMs), efficient inference through Key-Value (KV) cache compression has attracted considerable attention, especially for long-context generation. To compress the KV cache, recent methods identify critical KV tokens through static modeling of attention scores. However, these methods often struggle to accurately determine critical tokens because they neglect the temporal patterns in attention scores, resulting in a noticeable degradation in LLM performance. To address this challenge, we propose AttentionPredictor, the first learning-based method to directly predict attention patterns for KV cache compression and critical token identification. Specifically, AttentionPredictor learns a lightweight, unified convolution model to dynamically capture spatiotemporal patterns and predict the next-token attention scores. An appealing feature of AttentionPredictor is that it accurately predicts the attention scores while sharing a single unified prediction model, which consumes negligible memory, across all transformer layers. Moreover, we propose a cross-token critical cache prefetching framework that hides the token estimation time overhead to accelerate the decoding stage. By retaining most of the attention information, AttentionPredictor achieves 13$\times$ KV cache compression and 5.6$\times$ speedup in a cache offloading scenario with comparable LLM performance, significantly outperforming state-of-the-art methods. The code is available at https://github.com/MIRALab-USTC/LLM-AttentionPredictor.
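To make the idea in the abstract concrete, here is a minimal sketch of predicting next-step attention from a short attention history and then keeping only the highest-scoring tokens. It is not the authors' model: the function names `predict_next_scores` and `select_critical_tokens` are hypothetical, and a hand-written kernel stands in for the weights a trained convolutional predictor would learn. The toy history follows a sequential pattern, where the attention peak moves one token to the right at each step.

```python
from typing import List, Sequence


def predict_next_scores(history: Sequence[Sequence[float]],
                        kernel: Sequence[Sequence[float]]) -> List[float]:
    """Collapse a (steps x tokens) attention history into one predicted row.

    `kernel` has shape (steps x width) with odd width. Its fixed values are an
    assumption for illustration; a real predictor would learn them.
    """
    steps, tokens = len(history), len(history[0])
    width = len(kernel[0])
    assert len(kernel) == steps and width % 2 == 1
    half = width // 2
    predicted = []
    for j in range(tokens):
        total = 0.0
        for t in range(steps):
            for d in range(width):
                pos = j + d - half
                if 0 <= pos < tokens:  # zero padding at the sequence edges
                    total += kernel[t][d] * history[t][pos]
        predicted.append(total)
    return predicted


def select_critical_tokens(scores: Sequence[float], budget: int) -> List[int]:
    """Return the indices of the `budget` highest scores, in position order."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:budget])


# Three past decoding steps over six cached tokens; the peak moves 1 -> 2 -> 3.
history = [
    [0.10, 0.80, 0.10, 0.00, 0.00, 0.00],
    [0.00, 0.10, 0.80, 0.10, 0.00, 0.00],
    [0.00, 0.00, 0.15, 0.80, 0.05, 0.00],
]
# Hypothetical kernel: copy the latest row, shifted one token to the right.
kernel = [
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
]

predicted = predict_next_scores(history, kernel)
critical = select_critical_tokens(predicted, budget=2)
print(critical)  # [3, 4]
```

A static rule that keeps the tokens with the highest past attention would retain tokens 2 and 3 here, while the shifted prediction keeps tokens 3 and 4, which is where the sequential pattern is heading. This illustrates why modeling temporal patterns can change which KV entries are retained.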

Paper Structure

This paper contains 46 sections, 8 equations, 12 figures, 20 tables, 1 algorithm.

Figures (12)

  • Figure 1: A comparison of H2O, Quest, SeerAttention, and AttentionPredictor for identifying critical tokens at the next step from historical attention scores. Our learning-based spatiotemporal predictor captures the dynamic attention patterns and accurately predicts next-step attention scores.
  • Figure 2: Visualization of three predictable temporal attention patterns. Re-access shows repeated attention to specific tokens. Sequential shows attention progressing toward subsequent tokens. Seasonal exhibits periodic recurrence as alternating bands of high and uniform attention scores.
  • Figure 3: Overview of AttentionPredictor and the cross-token prefetching framework. (a) AttentionPredictor formulates the historical attention scores as a spatiotemporal sequence and predicts the attention at the next step with a pre-trained model. To enhance efficiency, the attention history is updated in a compressed form at each decoding step. (b) The cross-token prefetching framework asynchronously evaluates critical tokens and fetches the KV for the next token during LLM inference, thereby accelerating the decoding stage.
  • Figure 3: The attention prediction accuracy (%) across different cache sizes.
  • Figure 4: Evaluation results on the long-output reasoning task AIME2024 with QwQ-32B.
  • ...and 7 more figures
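
The cross-token prefetching described in the Figure 3 caption can be sketched roughly as follows. This is an illustrative assumption-laden toy, not the paper's implementation: `estimate_critical`, `fetch_kv`, `prefetch_for_next`, and `decode` are hypothetical names, the "predictor" simply keeps the most recent tokens, and a Python dictionary stands in for the offloaded KV cache. The point is only the scheduling: while the current token is being computed, a background worker selects and fetches the KV entries for the next token.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List


def estimate_critical(num_tokens: int, budget: int) -> List[int]:
    """Hypothetical placeholder predictor: keep the most recent `budget` tokens."""
    return list(range(max(0, num_tokens - budget), num_tokens))


def fetch_kv(cpu_cache: Dict[int, str], indices: List[int]) -> Dict[int, str]:
    """Stand-in for a host-to-device copy of the selected KV entries."""
    return {i: cpu_cache[i] for i in indices}


def prefetch_for_next(cpu_cache: Dict[int, str], budget: int) -> Dict[int, str]:
    """Select critical tokens, then fetch their KV entries."""
    indices = estimate_critical(len(cpu_cache), budget)
    return fetch_kv(cpu_cache, indices)


def decode(num_steps: int, prompt_len: int, budget: int) -> List[List[int]]:
    """Return, per decoding step, which cached tokens were attended to."""
    cpu_cache = {i: f"kv{i}" for i in range(prompt_len)}
    used = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Prefetch for the first decoding step before the loop starts.
        pending = pool.submit(prefetch_for_next, dict(cpu_cache), budget)
        for _ in range(num_steps):
            gpu_kv = pending.result()  # fetched while the previous token ran
            # Launch the next token's prefetch so it overlaps with the
            # computation below; it sees the cache as of the previous step.
            pending = pool.submit(prefetch_for_next, dict(cpu_cache), budget)
            used.append(sorted(gpu_kv))  # stand-in for attention over gpu_kv
            new_index = len(cpu_cache)
            cpu_cache[new_index] = f"kv{new_index}"  # offload the new token's KV
    return used


print(decode(num_steps=3, prompt_len=4, budget=2))  # [[2, 3], [2, 3], [3, 4]]
```

Because each prefetch is issued before the current token finishes, the selection is based on information from the previous step, which is why the second step in the output still uses tokens 2 and 3. That one-step lag is the price of hiding the estimation and transfer time behind the computation.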