Exploiting Sparsity for Long Context Inference: Million Token Contexts on Commodity GPUs
Ryan Synk, Monte Hoover, John Kirchenbauer, Neel Jain, Alex Stein, Manli Shu, Josue Melendez Sanchez, Ramani Duraiswami, Tom Goldstein
TL;DR
This work tackles the prohibitive memory and compute costs of self-attention over million-token contexts on commodity GPUs by introducing a tunable top-$k$ attention mechanism with a CPU-resident KV cache. During decoding, each layer retrieves only its top-$k$ keys via approximate nearest-neighbor search, so GPU memory grows as $O(k)$ rather than $O(N)$ in the context length $N$. This enables million-token generation with around $16\,\mathrm{GB}$ of GPU RAM while retaining over 95% of dense-attention performance when attending to fewer than 2% of the input tokens ($k < 0.02N$). Results on RULER, the Open LLM Leaderboard, and AlpacaEval show that very small values of $k$ suffice across diverse tasks, and that million-token generation on a single GPU is feasible using a prefilled KV cache. The findings support the practical viability of long-context inference on commodity hardware and motivate adaptive, layer-wise budget strategies to further improve the compute-performance tradeoff in deployment.
Abstract
There is growing demand for performing inference with hundreds of thousands of input tokens on trained transformer models. Inference at this extreme scale demands significant computational resources, hindering the application of transformers to long contexts on commodity (i.e., not data-center-scale) hardware. To address the inference-time costs of running self-attention-based transformer language models on long contexts, and to enable their adoption on widely available hardware, we propose a tunable mechanism that reduces the cost of the forward pass by attending only to the most relevant tokens at every generation step, chosen by top-k selection. We showcase the efficiency gains afforded by our method by performing inference on context windows of up to 1M tokens using approximately 16 GB of GPU RAM. Our experiments reveal that models are capable of handling the sparsity induced by the reduced number of keys and values. By attending to fewer than 2% of input tokens, we achieve over 95% of model performance on common benchmarks (RULER, AlpacaEval, and Open LLM Leaderboard).
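To make the top-k selection step concrete, the following is a minimal single-head, single-query sketch in NumPy. It is an illustration under stated assumptions, not the authors' implementation: the function names (`dense_attention`, `topk_attention`) are hypothetical, scores are computed exactly against every key, and the top-k set is found with `np.argpartition`. A system of the kind described here would instead retrieve the candidate keys from a CPU-resident cache without scoring all N keys on the GPU.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()


def dense_attention(q, K, V):
    """Standard attention for one query: softmax(K q / sqrt(d)) V."""
    scores = K @ q / np.sqrt(q.shape[0])
    return softmax(scores) @ V


def topk_attention(q, K, V, k):
    """Attend only to the k keys with the largest scores.

    Assumption: exact scoring of all N keys is used here for clarity;
    it stands in for whatever retrieval step selects the top-k keys.
    """
    scores = K @ q / np.sqrt(q.shape[0])
    k = max(1, min(k, scores.shape[0]))
    # Indices of the k largest scores (unordered), found in O(N) time.
    idx = np.argpartition(scores, -k)[-k:]
    # The softmax is renormalized over the selected keys only.
    return softmax(scores[idx]) @ V[idx]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    N, d = 10_000, 64
    K = rng.standard_normal((N, d))
    V = rng.standard_normal((N, d))
    q = rng.standard_normal(d)

    k = int(0.02 * N)  # attend to 2% of the context
    sparse = topk_attention(q, K, V, k)
    dense = dense_attention(q, K, V)
    print("relative error:", np.linalg.norm(sparse - dense) / np.linalg.norm(dense))
```

With `k` equal to the context length the sketch reduces to dense attention, which is a quick sanity check. The random keys used in the demo are not representative of trained models, so the printed error says nothing about the benchmark results reported above.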
