
Exploiting Sparsity for Long Context Inference: Million Token Contexts on Commodity GPUs

Ryan Synk, Monte Hoover, John Kirchenbauer, Neel Jain, Alex Stein, Manli Shu, Josue Melendez Sanchez, Ramani Duraiswami, Tom Goldstein

TL;DR

This work tackles the prohibitive memory and compute costs of self-attention for million-token contexts on commodity GPUs by introducing a tunable top-$k$ attention mechanism with a CPU-resident KV cache. During decoding, each layer retrieves only its top-$k$ keys via approximate nearest-neighbor search, so GPU memory grows as $O(k)$ rather than $O(N)$. This enables million-token generation with around $16\,\mathrm{GB}$ of GPU RAM while retaining roughly 95% of dense-attention performance when $k\approx 0.02N$. Empirical results across RULER, Open LLM Leaderboard, and AlpacaEval show that very small values of $k$ suffice for diverse tasks, and that million-token generation on a single GPU is feasible using a prefilled KV cache. The findings highlight the practical viability of long-context inference on commodity hardware and motivate adaptive, layer-wise budget strategies to further optimize compute-performance tradeoffs in deployment.

Abstract

There is growing demand for performing inference with hundreds of thousands of input tokens on trained transformer models. Inference at this extreme scale demands significant computational resources, hindering the application of transformers at long contexts on commodity (i.e., not data-center-scale) hardware. To address the inference-time costs associated with running self-attention-based transformer language models on long contexts and enable their adoption on widely available hardware, we propose a tunable mechanism that reduces the cost of the forward pass by attending to only the most relevant tokens at every generation step using a top-$k$ selection mechanism. We showcase the efficiency gains afforded by our method by performing inference on context windows up to 1M tokens using approximately 16GB of GPU RAM. Our experiments reveal that models are capable of handling the sparsity induced by the reduced number of keys and values. By attending to less than 2% of input tokens, we achieve over 95% of model performance on common benchmarks (RULER, AlpacaEval, and Open LLM Leaderboard).
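The top-$k$ selection described above can be illustrated with a minimal single-query sketch. It is not the authors' implementation: the function names (`topk_attention`, `softmax`, `dot`) are hypothetical, and it selects keys by exact brute-force scoring, whereas the TL;DR describes approximate nearest-neighbor search over a CPU-resident KV cache. The sketch only shows the core idea of applying softmax over the $k$ highest-scoring keys instead of all $N$.

```python
import heapq
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def topk_attention(query, keys, values, k):
    """Attend to only the k keys with the largest scaled inner product.

    Assumption: exact selection is used here for clarity. A real system
    would replace the full scoring pass with an approximate
    nearest-neighbor lookup so that not every key is touched.
    """
    scale = 1.0 / math.sqrt(len(query))
    scores = [dot(query, key) * scale for key in keys]
    k = min(k, len(keys))
    # Indices of the k highest-scoring keys.
    top = heapq.nlargest(k, range(len(keys)), key=scores.__getitem__)
    # Softmax is renormalized over the selected keys only.
    weights = softmax([scores[i] for i in top])
    dim = len(values[0])
    out = [0.0] * dim
    for w, i in zip(weights, top):
        for d in range(dim):
            out[d] += w * values[i][d]
    return out, sorted(top)


if __name__ == "__main__":
    q = [1.0, 0.0]
    keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.9, 0.1]]
    values = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0], [0.0, 0.0]]
    output, selected = topk_attention(q, keys, values, k=2)
    print(selected, output)
```

Setting `k` equal to the number of keys recovers ordinary dense attention for that query, which makes it easy to compare the sparse and dense outputs on the same inputs.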

Paper Structure

This paper contains 29 sections, 3 equations, 13 figures, 3 tables, 1 algorithm.

Figures (13)

  • Figure 1: (top) Typical attention requires each query vector to compute an inner product with each key vector in the context window. In practice, most key vectors produce insignificant attention scores, and therefore contribute very little to subsequent hidden states, so much of this computation is wasted. (left-bottom) top-k attention retrieves only the keys that contribute significantly to the attention computation, leaving the gray arrows out and achieving sublinear runtime.
  • Figure 2: Performance on selected OpenLLM Leaderboard tasks using only the top-$k$ keys for each attention computation. Typical questions have a context length of $\sim 1000$, yet only 10 keys are needed to achieve the same performance as full attention. We evaluate an extended set of models in \ref{fig:lm_eval_harness_topk_results}.
  • Figure 3: We analyze the number of attention scores that correspond to $75\%$ of the probability mass for generating the next token. Each point is the number of scores of the last 'row' from the attention matrix required to reach $75\%$ of the total attention. We observe each of the $32$ heads across $50$ samples. On the left, we plot the histogram for the first layer in the network, and on the right we plot it for layer 16.
  • Figure 4: Average attention score per token in a given Wikipedia article for a multi-article long context sample (across all heads and layers). The bar for the article prompted to be copied is highlighted in red. Note how the model focuses its attention on the correct document.
  • Figure 5: Entropy of the distribution of attention scores for each layer of a model, calculated as $E = - \sum_{i=1}^N a_i\log(a_i)$, where $(a_1,\dots,a_N)$ is the attention score distribution. Attention score distributions are derived from the last token of 50 samples of 1000 tokens each and aggregated over all heads for a given layer. Entropy serves as a measure of how concentrated the attention scores are for a given query token: low entropy indicates a large amount of attention centered on a few tokens, and high entropy indicates a more uniform dispersion of attention scores. Maximum entropy occurs when the distribution is perfectly uniform; for 1000-token contexts it is $-\log(\frac{1}{1000}) = \log(1000)$.
  • ...and 8 more figures
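
The two sparsity measures named in the captions of Figures 3 and 5 can be computed as in the sketch below. It is a minimal illustration, not the authors' analysis code: the function names are hypothetical, the input is assumed to be a single row of attention scores that already sums to 1, and the natural logarithm is assumed for the entropy.

```python
import math


def attention_entropy(probs):
    """Shannon entropy E = -sum_i a_i * log(a_i), taking 0 * log(0) as 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def scores_needed_for_mass(probs, mass=0.75):
    """Smallest number of the largest scores whose sum reaches `mass`."""
    total = 0.0
    for count, p in enumerate(sorted(probs, reverse=True), start=1):
        total += p
        if total >= mass:
            return count
    # Reached only if the scores sum to less than `mass`.
    return len(probs)


if __name__ == "__main__":
    peaked = [0.8, 0.1, 0.1]
    uniform = [0.25] * 4
    print(attention_entropy(peaked), scores_needed_for_mass(peaked))
    print(attention_entropy(uniform), scores_needed_for_mass(uniform))
```

A concentrated distribution gives low entropy and needs few scores to reach 75% of the mass, while a uniform distribution over $N$ tokens attains the maximum entropy $\log(N)$, matching the bound stated in the Figure 5 caption.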