OBCache: Optimal Brain KV Cache Pruning for Efficient Long-Context LLM Inference
Yuzhe Gu, Xiyu Liang, Jiaojiao Zhao, Enmao Diao
TL;DR
OBCache addresses the memory bottleneck of long-context LLM inference by recasting KV cache eviction as a structured pruning problem. Using a second-order Taylor expansion under the Optimal Brain Damage framework, it derives closed-form, output-aware saliency scores for a cached token $p$, namely $S_p^{\text{value}}$, $S_p^{\text{key}}$, and $S_p^{\text{joint}}$. These scores quantify how evicting that token's value, key, or key-value pair perturbs the attention outputs. Because they incorporate value states, pre-softmax logits, and attention outputs, OBCache's scores provide richer signals than prior attention-weight heuristics and encompass those heuristics as special cases. Experiments on LLaMA-3.1 and Qwen-2.5 show consistent improvements in long-context accuracy and perplexity when OBCache scores are integrated into existing KV cache eviction pipelines.
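As a worked illustration of why an attention weight alone can misrank tokens, the derivation below gives the exact output change for two simple eviction cases. It assumes a single head, a single query position $t$, and $n$ cached tokens with $a_{t,j} = \mathrm{softmax}_j(q_t^\top k_j/\sqrt{d})$ and $o_t = \sum_j a_{t,j} v_j$. It also assumes that pruning a value means setting $v_p$ to zero with the attention weights held fixed, and that evicting a whole token removes it from the softmax. These are exact expressions for this simplified setting, not the paper's second-order $S_p^{\text{value}}$, $S_p^{\text{key}}$, and $S_p^{\text{joint}}$ scores.

```latex
% Case 1: zero the value v_p; attention weights unchanged
\Delta o_t = -a_{t,p}\, v_p,
\qquad
\|\Delta o_t\|_2^2 = a_{t,p}^2\, \|v_p\|_2^2 .

% Case 2: remove token p entirely; the remaining weights renormalize
a'_{t,j} = \frac{a_{t,j}}{1 - a_{t,p}} \quad (j \neq p),
\qquad
o'_t = \frac{o_t - a_{t,p}\, v_p}{1 - a_{t,p}},

\Delta o_t = o'_t - o_t = \frac{a_{t,p}}{1 - a_{t,p}}\,(o_t - v_p),
\qquad
\|\Delta o_t\|_2^2 = \left(\frac{a_{t,p}}{1 - a_{t,p}}\right)^2 \|o_t - v_p\|_2^2 .
```

Even here, two tokens with the same weight $a_{t,p}$ can matter very differently. In the first case the value norm $\|v_p\|$ scales the damage. In the second, a token whose value is close to the current output $o_t$ can be removed at almost no cost. Accumulated attention weights discard exactly this value- and output-dependent information.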
Abstract
Large language models (LLMs) with extended context windows enable powerful downstream applications but impose significant memory overhead, because the memory needed to cache all key-value (KV) states grows linearly with sequence length and batch size. Existing cache eviction methods address this by exploiting attention sparsity, yet they typically rank tokens heuristically using accumulated attention weights without considering their true impact on attention outputs. We propose Optimal Brain Cache (OBCache), a principled framework that formulates cache eviction as a layer-wise structured pruning problem. Building on Optimal Brain Damage (OBD) theory, OBCache quantifies token saliency by measuring the perturbation in attention outputs induced by pruning tokens, with closed-form scores derived for isolated keys, isolated values, and joint key-value pairs. Our scores account not only for attention weights but also for information from value states and attention outputs, thereby enhancing existing eviction strategies with output-aware signals. Existing methods estimate token saliency by aggregating heuristic scores across different query positions. Experiments on LLaMA and Qwen models show that replacing these heuristic scores with OBCache's output-aware scores consistently improves long-context accuracy.
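Existing eviction pipelines typically score each cached token by aggregating a per-query signal over several query positions and then keep the highest-scoring tokens. The sketch below shows how such a pipeline can swap an accumulated-attention score for an output-aware one. It is a minimal NumPy illustration under stated assumptions, not OBCache's implementation. The assumptions are a single head and window queries that attend to every cached token without a causal mask. The output-aware score is also hypothetical: for each token $p$, it takes the exact squared change in each window query's output when $p$ is removed and the remaining weights renormalize, $(a_{t,p}/(1-a_{t,p}))^2\|o_t - v_p\|^2$, and sums it over the queries. The names `evict`, `attention_score`, `removal_score`, and `output_error` are illustrative.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attend(q, k, v):
    """Single-head attention without masking (a simplifying assumption).
    q: (T, d) window queries; k, v: (n, d) cached keys/values.
    Returns weights a: (T, n) and outputs o: (T, d)."""
    a = softmax(q @ k.T / np.sqrt(k.shape[1]), axis=-1)
    return a, a @ v


def attention_score(a, o, v):
    # Heuristic baseline: attention weight accumulated over the window queries.
    return a.sum(axis=0)


def removal_score(a, o, v):
    # Hypothetical output-aware score: exact squared output change when token p
    # is dropped and the other weights renormalize, summed over window queries:
    #   sum_t (a_tp / (1 - a_tp))^2 * ||o_t - v_p||^2
    a = np.clip(a, 0.0, 1.0 - 1e-6)                                  # guard a_tp == 1
    ratio = a / (1.0 - a)                                            # (T, n)
    diff_sq = ((o[:, None, :] - v[None, :, :]) ** 2).sum(axis=-1)    # (T, n)
    return (ratio ** 2 * diff_sq).sum(axis=0)                        # (n,)


def evict(q, k, v, budget, score_fn):
    """Keep the `budget` highest-scoring cached tokens; return their sorted indices."""
    a, o = attend(q, k, v)
    scores = score_fn(a, o, v)
    return np.sort(np.argsort(scores)[-budget:])


def output_error(q, k, v, keep):
    """Total squared change in window-query outputs after keeping only `keep`."""
    _, o_full = attend(q, k, v)
    _, o_kept = attend(q, k[keep], v[keep])
    return float(((o_full - o_kept) ** 2).sum())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d, T, budget = 64, 16, 8, 16      # cache size, head dim, window size, budget
    k, v = rng.normal(size=(n, d)), rng.normal(size=(n, d))
    q = rng.normal(size=(T, d))
    for name, fn in [("attention", attention_score), ("output-aware", removal_score)]:
        keep = evict(q, k, v, budget, fn)
        print(f"{name:>12}: output error {output_error(q, k, v, keep):.4f}")
```

Summing per-token removal scores treats each eviction independently, which is the same simplification any top-k ranking makes. The `output_error` helper measures the actual joint effect on random data. Which score wins there is not asserted here.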
