
Near-Oracle KV Selection via Pre-hoc Sparsity for Long-Context Inference

Yifei Gao, Lei Wang, Rong-Cheng Tu, Qixin Zhang, Jun Cheng, Dacheng Tao

TL;DR

This work tackles the long-context KV-attention bottleneck in autoregressive LLMs with Pre-hoc Sparsity (PrHS), a framework that bounds the information lost when a small subset of KV entries is selected before attention scoring. It derives a mutual-information (MI) bound that depends only on the dropped attention mass, so accuracy guarantees become verifiable once the dropped mass is controlled in advance. The authors instantiate three complementary selectors: Clustered Indices Sharing (CIS), Progressive Sliding Attention Window (PSAW), and Early Token Freezing (ETF). They integrate these into a parallelizable system called CPE, which reduces KV-retrieval overhead by over 90% while staying close to top-k oracle accuracy. Experiments on LLaMA and Mistral models across GSM8K, CoQA, and LongBench show up to 9.9x lower attention-operator latency and up to 2.8x higher throughput than the dense baseline, at comparable or better accuracy.

Abstract

A core bottleneck in large language model (LLM) inference is the cost of attending over the ever-growing key-value (KV) cache. Although near-oracle top-k KV selection can preserve the quality of dense attention while sharply reducing computation and bandwidth, existing sparse methods generally rely on posterior heuristics, i.e., selectors conditioned on observed attention or proxy scores. Such conditioning introduces posterior bias: it tends to distort true token importance and miss salient tokens, thereby impairing long-range reasoning. To tackle this problem, we propose Pre-hoc Sparsity (PrHS), which selects KV entries before attention scoring and provides explicit accuracy control. Let the attention mass of discarded entries be $\delta$ (the dropped mass). Through a marginal-to-mutual-information analysis, we derive an upper bound on the mutual-information loss that depends only on the dropped mass. This relation explains failure modes of posterior heuristics and enables verifiable guarantees by controlling the dropped mass in advance. Within PrHS, we instantiate three orthogonal pre-hoc selectors along the axes of time, depth, and layer. Extensive experiments on LLaMA and Mistral families validate PrHS. Across GSM8K and CoQA, PrHS reduces retrieval overhead by over 90%, achieving 3x higher retrieval sparsity than HShare at matched or better accuracy. It incurs under 1% average degradation on LongBench, lowers attention FLOPs by about 15% versus prior sparse baselines, and yields a 9.9x speedup in attention-operator latency and 2.8x higher throughput on NVIDIA A100-80GB GPUs than the dense baseline.
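To make the dropped mass $\delta$ concrete, here is a minimal NumPy sketch, assuming a single head and a single decoding query with random keys; the helper names (`attention_weights`, `dropped_mass`, `topk_oracle`) are hypothetical and not taken from the paper. It computes the attention weights, picks the top-k oracle set, and reports $\delta$ for that set and for a random set of the same size. For a fixed budget k, the top-k set always has the smallest possible $\delta$. Note that the oracle must score every key, which makes it a posterior selector; a pre-hoc selector tries to keep $\delta$ small without that full scoring pass.

```python
import numpy as np

def attention_weights(q, K):
    """Scaled dot-product attention weights A_j(q) for one query over t keys."""
    d = q.shape[-1]
    scores = K @ q / np.sqrt(d)        # shape (t,)
    scores = scores - scores.max()     # shift for numerical stability
    w = np.exp(scores)
    return w / w.sum()

def dropped_mass(weights, keep_idx):
    """delta: total attention mass of the entries NOT in keep_idx."""
    keep = np.zeros(weights.shape[0], dtype=bool)
    keep[keep_idx] = True
    return float(weights[~keep].sum())

def topk_oracle(weights, k):
    """Hypothetical helper: indices of the k largest attention weights."""
    return np.argsort(weights)[::-1][:k]

rng = np.random.default_rng(0)
t, d, k = 512, 64, 64
K = rng.standard_normal((t, d))        # assumption: random keys
q = rng.standard_normal(d)             # assumption: random query

A = attention_weights(q, K)
oracle = topk_oracle(A, k)
random_sel = rng.choice(t, size=k, replace=False)

print("delta (top-k oracle):", dropped_mass(A, oracle))
print("delta (random k):    ", dropped_mass(A, random_sel))
```

With real attention, which is far more peaked than this random example, the oracle's $\delta$ for small k is typically much smaller, and that regime is the one the near-oracle claim targets.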

Paper Structure

This paper contains 81 sections, 20 theorems, 87 equations, 8 figures, and 7 tables.

Key Result

Theorem 1

Consider the scaled dot-product self-attention (cf. the attention equation in the paper) with fixed keys $\{\mathbf{k}_j\}_{j=1}^t$, attention weights $A_j(\mathbf{q}_t)$ for key $\mathbf{k}_j$, and let $\{p_j\}_{j=1}^t$ be scalar token positions with $\mathrm{diam}\,\mathcal{P}:=\max_j p_j-\min_j p_j$ for $\mathcal{P}=\{p_j\}_{j=1}^t$. … i.e., the attention centroid moves in a Lipschitz-continuous fashion with respect to the query.
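The Lipschitz conclusion of Theorem 1 can be checked numerically. The sketch below assumes one head, positions $p_j = j$, and random keys; the centroid notation $c(\mathbf{q}) = \sum_j A_j(\mathbf{q})\,p_j$ is introduced here for illustration, and the exact constant of Theorem 1 is not reproduced in this summary. Differentiating the softmax gives $\nabla_{\mathbf{q}} c = \frac{1}{\sqrt{d}}\sum_j A_j (p_j - c)(\mathbf{k}_j - \bar{\mathbf{k}})$ with $\bar{\mathbf{k}} = \sum_j A_j \mathbf{k}_j$. Since $|p_j - c| \le \mathrm{diam}\,\mathcal{P}$ and $\|\mathbf{k}_j - \bar{\mathbf{k}}\| \le 2\max_j\|\mathbf{k}_j\|$, this yields the crude constant $L = 2\,\mathrm{diam}\,\mathcal{P}\,\max_j\|\mathbf{k}_j\|/\sqrt{d}$, which the code compares against observed drift.

```python
import numpy as np

def softmax(x):
    z = x - x.max()                    # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def attention_centroid(q, K, positions):
    """c(q) = sum_j A_j(q) p_j, the attention-weighted mean token position."""
    A = softmax(K @ q / np.sqrt(q.shape[-1]))
    return float(A @ positions)

rng = np.random.default_rng(1)
t, d = 256, 64
K = rng.standard_normal((t, d))                      # assumption: random keys
positions = np.arange(t, dtype=float)                # assumption: p_j = j
diam = positions.max() - positions.min()             # diam P
# Crude global Lipschitz constant from the gradient bound in the lead-in.
L_bound = 2.0 * diam * np.linalg.norm(K, axis=1).max() / np.sqrt(d)

q = rng.standard_normal(d)
c0 = attention_centroid(q, K, positions)
ratios = []
for _ in range(100):
    dq = 1e-2 * rng.standard_normal(d)               # a nearby query, e.g. the next step
    c1 = attention_centroid(q + dq, K, positions)
    ratios.append(abs(c1 - c0) / np.linalg.norm(dq))

print(f"max observed drift ratio {max(ratios):.2f} <= crude bound {L_bound:.2f}")
```

The bound is loose by design. The observation that matters for index sharing is that nearby queries place their attention mass around nearby positions.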

Figures (8)

  • Figure 1: Overall analysis of performance and efficiency. Across induced attention-approximation error and accuracy-efficiency trade-offs, our method consistently surpasses prior SOTA approaches and closely tracks the top-$k$ oracle for optimal KV compression accuracy.
  • Figure 2: Distribution of critical indices in LLaMA2-7B-Chat (Layer 10, Head 2). For each query, we retrieve 64 critical indices using the top-$k$ oracle on WikiText2. Columns show five temporally adjacent queries with cosine similarity $>\!0.8$. Left: overall distribution across keys 400–710; clusters are outlined in blue boxes. Right: zoom-ins of three clusters (marked by green arrows on the left) at keys 426–436, 591–700, and 691–704. Critical indices are in red.
  • Figure 3: Attention distribution on LLaMA2-7B-Chat. Heatmaps across representative layer–head pairs, with darker color indicating stronger attention.
  • Figure 4: Clustered Indices Sharing (CIS) dilation. For three adjacent queries (405–407), we examine two key-index clusters $[120,125]$ and $[220,225]$. Critical tokens from query 405 are shared with queries 406 and 407. Overlap is measured as $\frac{\text{Number of True Positives}}{\text{Number of Critical Tokens}}$. Tokens in violet are also critical for query 405.
  • Figure 5: Visual illustration of PSAW and ETF in the prefill stage. When $\ell<\ell_s$, attention is unchanged. For $\ell\ge\ell_s$, both methods prune redundant computation: PSAW computes a per-step sliding window, so the set of masked tokens can vary across steps (columns); ETF applies a fixed prune range and freezes earlier tokens so they no longer update.
  • ...and 3 more figures
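The sharing idea in Figures 2 and 4 (compute critical indices once for an anchor query, then reuse them for similar neighbouring queries) and the overlap metric from Figure 4 can be sketched as follows. This is an illustration under stated assumptions, not the paper's CIS implementation: keys are random, and the follow-up queries are synthetic perturbations of the anchor that play the role of queries 406 and 407. The code also reports the dropped mass $\delta$ incurred by reusing the anchor's indices, next to the oracle's $\delta$ for the same query.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())            # shift for numerical stability
    return e / e.sum()

def topk_indices(q, K, k):
    """Top-k oracle: (set of the k highest-weight indices, full weight vector)."""
    A = softmax(K @ q / np.sqrt(q.shape[-1]))
    return set(np.argsort(A)[::-1][:k].tolist()), A

def overlap(shared, critical):
    """Number of true positives / number of critical tokens (Figure 4 metric)."""
    return len(shared & critical) / len(critical)

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(2)
t, d, k = 512, 64, 64
K = rng.standard_normal((t, d))                  # assumption: random keys
q_anchor = rng.standard_normal(d)                # plays the role of query 405
shared, _ = topk_indices(q_anchor, K, k)         # critical indices computed once

results = []
for step in (1, 2):                              # two follow-up queries, like 406 and 407
    q_next = q_anchor + 0.3 * rng.standard_normal(d)   # assumption: a nearby query
    critical, A = topk_indices(q_next, K, k)
    delta_shared = 1.0 - A[sorted(shared)].sum()       # dropped mass when reusing indices
    delta_oracle = 1.0 - A[sorted(critical)].sum()     # best possible dropped mass at budget k
    results.append((cosine(q_anchor, q_next), overlap(shared, critical),
                    delta_shared, delta_oracle))
    print(f"step {step}: cos={results[-1][0]:.2f} overlap={results[-1][1]:.2f} "
          f"delta_shared={delta_shared:.3f} delta_oracle={delta_oracle:.3f}")
```

Reusing indices always costs at least the oracle's $\delta$. The question a CIS-style selector must answer is how small that gap stays when queries are highly similar.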

Theorems & Definitions (37)

  • Theorem 1: Centroid Drift
  • Theorem 2: CIS Retained–Mass and MI Guarantee
  • Remark 1: We bound information via the index channel
  • Lemma 1: Total-variation of truncation
  • proof
  • Lemma 2: Continuity of mutual information (MI) under TV perturbations
  • proof
  • Proposition 1: Universal MI bounds & KL variant
  • proof
  • Theorem 3: Oracle top-$k$ information bound
  • ...and 27 more
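Lemma 1's title points to a standard identity linking truncation to the dropped mass. The exact statement is not reproduced in this summary, so treat the following as an assumed reading. If $A$ is the attention distribution and $\tilde A$ is its restriction to a retained set $S$, renormalized so that $\tilde A_j = A_j/(1-\delta)$ on $S$, then $\mathrm{TV}(A,\tilde A) = \tfrac12\big[\sum_{j\in S} A_j\big(\tfrac{1}{1-\delta}-1\big) + \sum_{j\notin S} A_j\big] = \tfrac12[\delta+\delta] = \delta$. A short numerical check, with a synthetic attention-like distribution (an assumption):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())            # shift for numerical stability
    return e / e.sum()

def truncate_renormalize(p, keep_idx):
    """Keep only the entries in keep_idx and renormalize (KV selection)."""
    q = np.zeros_like(p)
    q[keep_idx] = p[keep_idx]
    return q / q.sum()

def total_variation(p, q):
    return 0.5 * float(np.abs(p - q).sum())

rng = np.random.default_rng(3)
t = 1000
p = softmax(3.0 * rng.standard_normal(t))        # assumption: peaked, attention-like weights
keep = rng.choice(t, size=100, replace=False)    # any retained set S
delta = 1.0 - p[keep].sum()                      # dropped mass
tv = total_variation(p, truncate_renormalize(p, keep))
print(f"delta={delta:.6f}  TV={tv:.6f}")
```

Because the identity holds for any retained set, a TV-based continuity argument such as Lemma 2 can then be phrased purely in terms of $\delta$.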