TokenButler: Token Importance is Predictable

Yash Akhauri, Ahmed F AbouElhamayed, Yifei Gao, Chi-Chih Chang, Nilesh Jain, Mohamed S. Abdelfattah

TL;DR

TokenButler tackles the KV-Cache memory bottleneck in long-context LLMs by introducing a lightweight, per-head token-importance predictor that approximates full attention logits at decode time. Trained with a mean-squared-error objective against the true pre-softmax logits while the LLM is frozen, it delivers fine-grained, query-aware token prioritization with negligible latency overhead ($<2\%$) and parameter cost ($<1.2\%$ of the LLM). Empirical results show improvements of over 8% in perplexity and downstream accuracy relative to state-of-the-art token-importance methods, along with strong performance on a synthetic co-reference benchmark, standard benchmarks, and reasoning models. The work demonstrates that per-head, high-granularity token selection can preserve essential context in co-reference and complex reasoning tasks while reducing memory bandwidth demands, pointing to practical deployment of decode-time token loading strategies guided by learned importance.

Abstract

Large Language Models (LLMs) rely on the Key-Value (KV) Cache to store token history, enabling efficient decoding of tokens. As the KV-Cache grows, it becomes a major memory and computation bottleneck. However, there is an opportunity to alleviate this bottleneck, because prior research has shown that only a small subset of tokens contributes meaningfully to each decoding step. A key challenge in finding these critical tokens is that they are dynamic and depend heavily on the input query. Existing methods either risk quality by evicting tokens permanently, or retain the full KV-Cache but rely on retrieving chunks (pages) of tokens at generation, failing at dense, context-rich tasks. Additionally, many existing KV-Cache sparsity methods rely on inaccurate proxies for token importance. To address these limitations, we introduce TokenButler, a high-granularity, query-aware predictor that learns to identify these critical tokens. By training a lightweight predictor with less than 1.2% parameter overhead, TokenButler prioritizes tokens based on their contextual, predicted importance. This improves perplexity and downstream accuracy by over 8% relative to SoTA methods for estimating token importance. We evaluate TokenButler on a novel synthetic small-context co-referential retrieval task, demonstrating near-oracle accuracy. Code, models and benchmarks: https://github.com/abdelfattah-lab/TokenButler
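
As a rough illustration of what per-head, query-aware prioritization could look like at a single decode step, the sketch below keeps the highest-scoring cached tokens for each head, given predicted importance scores. It is not taken from the TokenButler repository: the function names, the fixed per-head `budget`, and the top-k selection rule are assumptions made here for illustration only.

```python
from typing import List


def top_k_indices(scores: List[float], k: int) -> List[int]:
    """Return the indices of the k highest scores, in token-position order."""
    # Python's sort is stable, so ties are broken in favor of earlier tokens.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


def select_tokens_per_head(pred_logits: List[List[float]], budget: int) -> List[List[int]]:
    """Pick which cached tokens each head would load for the current query.

    pred_logits[h][t] is the predicted pre-softmax logit of cached token t
    for head h (hypothetical layout). Each head gets its own selection, so
    two heads may load different tokens for the same decode step.
    """
    return [top_k_indices(head_scores, budget) for head_scores in pred_logits]
```

Selecting per head rather than per page or per layer is what the abstract calls high granularity: a token that matters to only one head can still be loaded for that head without enlarging the budget of the others.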

Paper Structure

This paper contains 15 sections, 3 equations, 9 figures, 2 tables.

Figures (9)

  • Figure 1: Full-Attention preserves all tokens, enabling access to the critical token (dark green) during the last decode step. Static strategies like StreamingLLM will not be able to access this token. Methods like $H_{2}O$ may have evicted the token at an earlier decode step, if it was deemed unimportant. Paged-Token importance may cause a page-miss of a critical token in context-dense tasks. TokenButler can effectively predict critical tokens, and can be leveraged by existing methods to offer both high-granularity and cheap importance estimation.
  • Figure 2: TokenButler is a lightweight predictor, with a down projection $D_{\mathrm{proj}}$ for cheaper attention, an attention layer, and Key-Query projection neural networks. These $\{Q_{\mathrm{imp}}, K_{\mathrm{imp}}\}$ effectively map the output of the attention mechanism to $N\times H$ Key-Query projection tensors (N: Num. Layers, H: Num. Heads) on a small interaction-dimension $d \ll E$. The full (pre-softmax) attention logits can then be computed for every head across all layers by taking the product $QK^{T}$. At train-time, we simply minimize the MSE between the true and predicted (pre-softmax) attention logits to learn the LLM behavior.
  • Figure 3: We train predictors within a $[1, 1.2]\%$ parameter count budget compared to the target LLM. Token Classification Accuracy indicates the classification accuracy in identifying the top 50% most important tokens. We achieve between 70% and 75% classification accuracy.
  • Figure 4: We measure the overhead of our predictor across sequence lengths, reporting the median over 100 iterations on an NVIDIA A6000 GPU for Llama models. We do not implement any sparsity or optimizations; we measure only the overhead of prefill token-importance estimation. Our predictor scales with the model size ($\approx 1\%$), and thus adds $< 2\%$ latency overhead.
  • Figure 5: For different methods, we present the true decode-simulation of token access patterns for a sample head on Llama-3.2-3B. The tokens that are accessed at the decode step for location ziramelgrove are underlined. The tokens that were mis-predicted at the last step of the decode simulation are in red. The positions of the location are in green, and if the final location is mis-predicted, that token is struck out in red. Coverage (number of correctly predicted tokens) is excellent for both Oracle and TokenButler.
  • ...and 4 more figures
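
The data flow described in the Figure 2 caption and the metric in the Figure 3 caption can be sketched for a single head as follows. This is a simplified illustration, not the authors' implementation: it omits the predictor's internal attention layer, replaces the Key-Query projection networks with single linear maps, applies no logit scaling, and averages the MSE over causal positions only. The helper names and the matrix shapes (`hidden` is T x E, `d_proj` is E x r, `q_imp` and `k_imp` are r x d) are assumptions, as is the exact way the top-50% accuracy is computed.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product (a is m x n, b is n x p)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def predicted_logits(hidden: Matrix, d_proj: Matrix, q_imp: Matrix, k_imp: Matrix) -> Matrix:
    """Predict T x T pre-softmax logits for one head from hidden states."""
    low = matmul(hidden, d_proj)           # down projection to a cheaper width
    q = matmul(low, q_imp)                 # T x d importance queries
    k = matmul(low, k_imp)                 # T x d importance keys
    k_t = [list(col) for col in zip(*k)]   # d x T
    return matmul(q, k_t)                  # Q K^T


def causal_mse(pred: Matrix, true: Matrix) -> float:
    """Mean squared error over positions j <= i (assumed causal masking)."""
    sq_errors = [
        (pred[i][j] - true[i][j]) ** 2
        for i in range(len(true))
        for j in range(i + 1)
    ]
    return sum(sq_errors) / len(sq_errors)


def top_half_accuracy(pred_row: List[float], true_row: List[float]) -> float:
    """Fraction of tokens whose top-50% membership agrees between pred and true."""
    k = len(true_row) // 2

    def top(scores: List[float]) -> set:
        return set(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k])

    p, t = top(pred_row), top(true_row)
    return sum((i in p) == (i in t) for i in range(len(true_row))) / len(true_row)
```

In a training loop, `causal_mse` would play the role of the loss that is minimized while the LLM stays frozen, and `top_half_accuracy` would be a held-out diagnostic in the spirit of the Token Classification Accuracy reported in Figure 3.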