
In-context KV-Cache Eviction for LLMs via Attention-Gate

Zihao Zeng, Bokai Lin, Tianqi Hou, Hao Zhang, Zhijie Deng

TL;DR

This work tackles the KV-Cache bottleneck in large language model inference by introducing Attention-Gate (AG), a lightweight, trainable module placed before each self-attention layer. Using global contextual information, AG outputs eviction flags that selectively prune KV states. It uses a reduced multi-head attention mechanism $\text{MHA}'$ to produce eviction decisions and a gating function to determine retention per token and per head, enabling flexible per-head, per-layer eviction with minimal overhead. Training uses an Eviction Loss to target a desired eviction rate and a straight-through estimator (STE) to handle the discrete gating. The method achieves strong results on continual pre-training and supervised fine-tuning benchmarks while reducing memory usage. The approach demonstrates robust, context-aware token eviction that can improve both inference efficiency and model performance, with practical implications for scaling LLMs to longer contexts. Overall, AG provides a scalable method to dynamically manage KV-Cache contents, preserving crucial tokens while discarding redundant ones in a task- and context-sensitive manner.
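The gating and training ideas above (discrete keep/evict flags, an eviction-rate target, and an STE for the non-differentiable step) can be illustrated with a small pure-Python sketch. This is not the authors' implementation: the sigmoid gate with a 0.5 threshold, the squared-gap form of the eviction loss, and the helper names `gate_flags`, `ste_backward`, and `eviction_loss` are all assumptions made for illustration.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gate_flags(logits, threshold=0.5):
    """Map gate logits to hard keep/evict flags.

    logits[h][t]: hypothetical gate score for head h and token t.
    Returns (flags, probs); flags[h][t] == 1 keeps that KV state, 0 evicts it.
    """
    probs = [[sigmoid(z) for z in row] for row in logits]
    flags = [[1 if p >= threshold else 0 for p in row] for row in probs]
    return flags, probs


def ste_backward(grad_flags, probs):
    """Straight-through estimator for the hard threshold.

    The step function has zero gradient almost everywhere, so the backward
    pass treats it as the identity and differentiates only the sigmoid:
    d(flag)/d(logit) is approximated by p * (1 - p).
    """
    return [[g * p * (1.0 - p) for g, p in zip(g_row, p_row)]
            for g_row, p_row in zip(grad_flags, probs)]


def eviction_loss(probs, target_evict_rate):
    """Hypothetical eviction loss: squared gap between the mean soft
    eviction probability (1 - p) and the target eviction rate."""
    n = sum(len(row) for row in probs)
    mean_evict = sum(1.0 - p for row in probs for p in row) / n
    return (mean_evict - target_evict_rate) ** 2


if __name__ == "__main__":
    logits = [[2.0, -1.0, 0.0], [-3.0, 1.0, 4.0]]  # 2 heads x 3 tokens
    flags, probs = gate_flags(logits)
    print(flags)  # [[1, 0, 1], [0, 1, 1]]
    print(eviction_loss(probs, target_evict_rate=0.5))
```

Because each head has its own row of logits, the same token can be kept by one head and evicted by another, which is the per-head flexibility the summary describes.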

Abstract

The KV-Cache technique has become the standard for the inference of large language models (LLMs). Yet, it is widely recognized that the KV-Cache can become a bottleneck of the LLM inference system. This paper enables a novel dynamic KV-Cache eviction policy by injecting a lightweight module called Attention-Gate into the model. It accepts the global context as input and yields eviction flags for each token. The self-attention modules in the model proceed according to the flags and cache only a subset of the KV states for next-token prediction. The Attention-Gates can yield different flags for different heads and layers and can be easily tuned on top of a pre-trained LLM via continual pre-training or supervised fine-tuning. The computational and memory overhead introduced by Attention-Gates can be minimal. We empirically evaluate the proposed approach across multiple scenarios, showing that effective eviction of redundant tokens can not only improve efficiency but also enhance performance.
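To make "cache only a subset of the KV states" concrete, here is a toy single-head sketch in pure Python. It assumes the per-token keep flags are already known; `HeadKVCache`, `decode_attention`, and `kv_cache_bytes` are hypothetical names, not the paper's code. During decoding the current token attends to every retained cached token plus its own key and value, which stay visible regardless of its flag. The size helper uses Llama2-7B dimensions (32 layers, 32 heads, head dimension 128, fp16) and counts only the KV-Cache, not weights or activations.

```python
import math


class HeadKVCache:
    """Toy KV cache for one attention head that stores only the tokens
    whose (hypothetical) gate flag is 1."""

    def __init__(self):
        self.keys = []
        self.values = []

    def store(self, key, value, keep):
        # Evicted tokens (keep == 0) are never written to the cache.
        if keep:
            self.keys.append(key)
            self.values.append(value)


def decode_attention(query, key, value, cache):
    """Attention output for the current token over the retained cached
    tokens plus the current token's own key/value."""
    keys = cache.keys + [key]
    values = cache.values + [value]
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, kvec)) for kvec in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(value))]


def kv_cache_bytes(num_tokens, n_layers=32, n_heads=32, head_dim=128,
                   bytes_per_elem=2, keep_frac=1.0):
    """KV-Cache size: K and V, per layer, per head, per retained token.
    Defaults match Llama2-7B in fp16; keep_frac is the average fraction of
    tokens retained across heads and layers."""
    per_token = 2 * n_layers * n_heads * head_dim * bytes_per_elem
    return int(num_tokens * per_token * keep_frac)


if __name__ == "__main__":
    cache = HeadKVCache()
    for t, keep in enumerate([1, 0, 1, 0]):
        cache.store([float(t)], [float(t)], keep)
    print(len(cache.keys))  # 2
    print(kv_cache_bytes(4096) / 2**30)  # 2.0 (GiB)
    print(kv_cache_bytes(4096, keep_frac=0.5) / 2**30)  # 1.0 (GiB)
```

At 4,096 tokens the full Llama2-7B KV-Cache is 2 GiB in fp16, so retaining about half the states roughly halves that term. Total peak memory also includes weights and activations, so its relative reduction is smaller.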

Paper Structure

This paper contains 25 sections, 11 equations, 5 figures, 6 tables.

Figures (5)

  • Figure 1: KV-Cache eviction patterns across different layers and attention heads, visualized for 4 samples from the PIQA dataset (top row) and 4 samples from the BoolQ dataset (bottom row), using AG fine-tuned Llama2-7B models. Black areas represent tokens that are neither computed nor stored in the KV-Cache. The variability of eviction patterns across tasks, prompts, layers, and attention heads demonstrates the dynamic nature of our method. A common trend is that deeper layers tend to mask more KV-Cache states, with some heads in deeper layers being entirely masked.
  • Figure 2: An overview of Attention-Gate (AG) for KV-Cache eviction. AG is a lightweight learnable module placed before each MHA layer. Given the input hidden states, it determines for each head whether to retain or discard the key and value tokens in the KV-Cache. In the attention weights, this corresponds to masking out columns for the evicted keys, while keeping the diagonal intact to ensure the query interacts with its own key.
  • Figure 3: Comparison of peak memory usage and prefilling time between the Llama2-7B model (without AG) and the proposed implementation (with AG and about 50% eviction) across varying prompt lengths. The results show significant improvements in memory efficiency with AG, especially as prompt length increases. Prefilling time is not the primary focus, and the current implementation (marked with * in the legend) relies on a suboptimal for-loop over attention heads. Even so, the method keeps prefilling time stable and shows a clear reduction trend for longer prompts.
  • Figure 4: This visualization highlights attention patterns in Llama2-7B after fine-tuning on the BoolQ dataset, showing multiple heads within both MHA and AG across different layers for a selected sample. In part (i), we visualize attention scores from several MHA heads across layers before eviction: 1. MHA heads display diverse attention patterns, especially in the first two layers, where heterogeneity is prominent. 2. Attention patterns transition from dense in early layers to progressively sparser in deeper layers. 3. Bright-yellow vertical lines, indicating critical tokens for inference, appear consistently across heads in deeper layers. These align with the Heavy Hitters in H2O (zhang2024h2o), i.e., tokens that contribute significantly to attention scores. Our method ensures these critical tokens are preserved in deeper layers, maintaining their importance across the network. In part (ii), we visualize the attention-like scores from the AG mechanism: as layers deepen, AG shifts from high-resolution to lower-resolution attention, focusing on distilling in-context information. Deeper AG layers no longer require high resolution to capture global context, because earlier layers have already refined it. This suggests potential optimizations, such as reducing the number of heads or dimensions in deeper AG layers, to improve efficiency.
  • Figure 5: The complete version of Figure 4. Two key observations emerge in (i): 1. the first two layers are denser than the subsequent layers, and 2. bright-yellow vertical lines, representing critical tokens for inference, consistently appear across heads in deeper layers.
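The masking rule described in Figure 2 (evicted keys become masked-out columns, while the diagonal stays intact so each query still sees its own key) can be sketched for a single head as follows. This is a minimal pure-Python illustration, assuming the gate's decisions for one head are already available as a list of 0/1 flags; `ag_attention_mask` and `to_additive` are hypothetical helper names, not the authors' code.

```python
def ag_attention_mask(keep):
    """Boolean prefill mask for one head.

    keep[j] is the gate flag for key position j (1 = retain, 0 = evict).
    mask[i][j] is True when query i may attend to key j: the usual causal
    constraint (j <= i), with evicted key columns masked out except on the
    diagonal, so every query still attends to its own key.
    """
    n = len(keep)
    return [[j <= i and (keep[j] == 1 or i == j) for j in range(n)]
            for i in range(n)]


def to_additive(mask):
    """Convert a boolean mask to the additive form (0 or -inf) that is
    added to attention logits before the softmax."""
    return [[0.0 if ok else float("-inf") for ok in row] for row in mask]


if __name__ == "__main__":
    # Token 1 is evicted: its column is masked everywhere except row 1.
    for row in ag_attention_mask([1, 0, 1, 1]):
        print("".join("x" if ok else "." for ok in row))
    # x...
    # xx..
    # x.x.
    # x.xx
```

Since the flags differ per head, a full layer would build one such mask per head; the for-loop over heads mentioned in the Figure 3 caption is one straightforward, if slow, way to apply them.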