Linear Predictability of Attention Heads in Large Language Models

Khalid Shaikh, Asmit Kumar Singh, Rebecca Christopher Dsouza, Shikhar Shiromani

Abstract

Large language model (LLM) inference is increasingly bottlenecked by the Key-Value (KV) cache, yet the fine-grained structure of attention-head activations remains poorly understood. We show that pretrained Transformers exhibit a pervasive inter-head linear structure: for a given token, the Query, Key, and Value (QKV) vectors of an attention head can often be reconstructed as a linear combination of a small number of peer heads, typically within the same layer. Across Llama-3.1-8B, Falcon3-10B, OLMo-2-7B, and Qwen3-32B, just 2-5 reference heads recover many target heads with high fidelity (e.g., mean $R^2 \approx 0.76$ for Keys on C4 with five references, and frequently $R^2 > 0.85$ on GSM8K). This predictability is learned rather than architectural: it is largely absent at random initialization and rises rapidly during pretraining, as we track through OLMo-2 checkpoints. A theoretical lower bound supports this, showing high mean-squared error for linear prediction at initialization. We further connect this emergence to increasing intra-layer alignment of Key projection subspaces. Finally, we exploit this redundancy for efficiency by caching only reference-head KV states and reconstructing the remaining heads on the fly via lightweight linear maps. This achieves 2$\times$ KV-cache reduction with model-dependent accuracy trade-offs (a 4.5-5.5 percentage point average drop on Falcon3-10B and Qwen3-32B across five benchmarks, and larger drops on Llama-3.1-8B). We also find that reconstructing Keys is substantially less harmful than reconstructing Values.
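
To make the central measurement concrete, the following is a minimal sketch of a linear probe between heads: a least-squares map from concatenated reference-head vectors to a target head, scored by held-out $R^2$. Everything here is an illustrative assumption rather than the authors' code: the function names, the absence of an intercept term, the pooled $R^2$ definition, and the synthetic Gaussian activations that stand in for real model activations.

```python
import numpy as np

def fit_linear_probe(refs, target):
    """Least-squares W minimizing ||refs @ W - target||_F (no intercept, by assumption)."""
    W, *_ = np.linalg.lstsq(refs, target, rcond=None)
    return W

def r_squared(refs, target, W):
    """1 - SS_res / SS_tot, pooled over tokens and head dimensions."""
    ss_res = ((target - refs @ W) ** 2).sum()
    ss_tot = ((target - target.mean(axis=0, keepdims=True)) ** 2).sum()
    return 1.0 - ss_res / ss_tot

rng = np.random.default_rng(0)
n_tokens, d_head, n_refs = 2000, 16, 3

# Synthetic stand-ins for per-token activations (NOT real model data).
refs = rng.standard_normal((n_tokens, n_refs * d_head))  # concatenated reference heads
mixing = rng.standard_normal((n_refs * d_head, d_head))
predictable = refs @ mixing + 0.5 * rng.standard_normal((n_tokens, d_head))
unrelated = rng.standard_normal((n_tokens, d_head))      # independent of refs

split = 1500  # fit on the first 1500 tokens, score on the remaining 500

def held_out_r2(target):
    W = fit_linear_probe(refs[:split], target[:split])
    return r_squared(refs[split:], target[split:], W)

r2_predictable = held_out_r2(predictable)
r2_unrelated = held_out_r2(unrelated)
print(f"predictable head R^2: {r2_predictable:.3f}")
print(f"unrelated head R^2:   {r2_unrelated:.3f}")
```

Scoring on held-out tokens keeps the probe from being credited for fitting noise, which matters when the number of reference dimensions is large relative to the number of tokens.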

Paper Structure

This paper contains 19 sections, 2 theorems, 20 equations, 8 figures, 6 tables.

Key Result

Theorem 3.1

Let ${\bm{A}}, {\bm{B}} \in \mathbb{R}^{m \times k}$ be two random matrices where each entry is sampled independently from $\mathcal{N}(0, 1)$ and $k \leq m/2$. Let ${\bm{x}} \in \mathbb{R}^{m}$ be an input vector drawn from $\mathcal{N}({\bm{0}}, {\bm{I}}_m)$ independently of ${\bm{A}}, {\bm{B}}$. Equivalently, with probability at least $1-2e^{-c_1 m}$ we have
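
The theorem statement above is cut off in this excerpt, so its exact bound is not reproduced here. The setting it describes can still be simulated numerically. The sketch below draws ${\bm{A}}$, ${\bm{B}}$, and inputs ${\bm{x}}$ as in the theorem, then fits the best linear map from ${\bm{A}}^\top {\bm{x}}$ to ${\bm{B}}^\top {\bm{x}}$ and reports the held-out normalized error. Reading ${\bm{A}}^\top {\bm{x}}$ and ${\bm{B}}^\top {\bm{x}}$ as the activations of a reference head and a target head at initialization is an assumption made for illustration, as are the sizes `m`, `k`, and the sample counts.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k = 128, 16                  # input width and head width, chosen so that k <= m/2
n_train, n_test = 4000, 1000    # hypothetical sample counts

A = rng.standard_normal((m, k))  # assumed "reference head" projection, entries N(0, 1)
B = rng.standard_normal((m, k))  # assumed "target head" projection, entries N(0, 1)
X = rng.standard_normal((n_train + n_test, m))  # each row is an input x ~ N(0, I_m)

ref, tgt = X @ A, X @ B          # rows are A^T x and B^T x

# Best linear map from reference activations to target activations.
W, *_ = np.linalg.lstsq(ref[:n_train], tgt[:n_train], rcond=None)

# Held-out squared error, normalized by the energy of the target.
resid = tgt[n_train:] - ref[n_train:] @ W
normalized_mse = (resid ** 2).sum() / (tgt[n_train:] ** 2).sum()
print(f"held-out normalized MSE: {normalized_mse:.3f}")
```

In this toy setting the normalized error stays close to 1, meaning the reference head explains only a small share of the target's variance. That is consistent with the abstract's claim that linear predictability is largely absent at random initialization.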

Figures (8)

  • Figure 1: Linear predictability among attention heads enables significant KV-cache compression. (Left) Schematic of approximating a target head's (orange) Key activations using a learned linear projection from reference heads (green). (Middle) In a pretrained Llama-3.1-8B model, most heads' Key activations are highly predictable (high $R^2$) from a few peers on the C4 dataset (Raffel et al., 2023), indicating a shared low-rank subspace. (Right) This predictability allows for substantial KV-cache compression with minimal performance degradation on benchmarks like TruthfulQA, MMLU-STEM, and Winogrande. The dashed line marks 2$\times$ compression.
  • Figure 2: Linearity is pervasive: Keys, Queries, and Values all become more predictable as the number of reference heads grows. Shown is the mean $R^{2}$ when reconstructing Key, Query, and Value activations on a shared 150-sequence slice of C4 for three pretrained models: (Left) Llama-3.1-8B, (Middle) Falcon3-10B, (Right) OLMo-2-7B. In every panel, each stream (K, Q, V) shows a monotonic rise in $R^{2}$ as references are added, confirming that head-level representations in each model occupy a low-rank subspace that can be captured with only a handful of heads.
  • Figure 3: Dominant predictors migrate from cross-layer to within-layer. For OLMo-2-7B we track, across checkpoints, where a head's strongest linear predictors reside. (Left) Percentage of heads whose single best ($R^{2}$-max) predictor lies in the same layer versus a different layer. (Right) The same percentages after aggregating each head's five best predictors. At initialization almost all dominant links are inter-layer; by 50k–100k steps intra-layer links overtake them and remain prevalent, indicating that pretraining progressively concentrates shared computation inside individual layers.
  • Figure 4: (Left) Pretraining aligns head subspaces. Overlap dimension averaged across layers at different points during OLMo-2 training, indicating increasing intra-layer alignment. (Right) Training induces linear predictability. Kernel-density estimate of $R^{2}$ for Llama-3.1-8B shows a broad moderate-to-high peak after training versus near-zero values at random initialization.
  • Figure 5: Pipeline for KV-cache compression via head-level redundancy. (1) We log Key/Value activations on a calibration set. (2) Pairwise linear probes quantify how well one head predicts another. (3) A thresholded graph selects a minimal reference set that covers all targets. (4) Compact weight matrices $W_{r \to t}$ are trained for those links. (5) During inference, we cache only the reference heads and reconstruct the others on the fly, shrinking memory with minimal extra compute.
  • ...and 3 more figures
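
The five pipeline steps in the Figure 5 caption can be sketched end to end on toy data. This is a hedged illustration, not the authors' implementation. The caption does not say how the reference set is selected, so the greedy set cover used here is an assumed heuristic that does not guarantee a minimal set. The helper names, the 0.8 threshold, the single-reference maps, and the synthetic activations are all hypothetical.

```python
import numpy as np

def r2_from(src, dst):
    """In-sample R^2 of a least-squares map from head `src` to head `dst`."""
    W, *_ = np.linalg.lstsq(src, dst, rcond=None)
    ss_res = ((dst - src @ W) ** 2).sum()
    ss_tot = ((dst - dst.mean(axis=0, keepdims=True)) ** 2).sum()
    return 1.0 - ss_res / ss_tot

def pairwise_r2(acts):
    """acts: (n_heads, n_tokens, d_head). r2[i, j] scores predicting head j from head i."""
    n = acts.shape[0]
    return np.array([[1.0 if i == j else r2_from(acts[i], acts[j])
                      for j in range(n)] for i in range(n)])

def greedy_reference_set(r2, threshold):
    """Greedy set cover (assumed heuristic); requires threshold <= 1."""
    n = r2.shape[0]
    covers = r2 >= threshold  # diagonal is True, so every pass covers at least one head
    uncovered, refs = set(range(n)), []
    while uncovered:
        best = max(range(n), key=lambda i: sum(bool(covers[i, j]) for j in uncovered))
        refs.append(best)
        uncovered -= {j for j in range(n) if covers[best, j]}
    return refs

# Toy activations: two independent "base" heads, each with two noisy linear copies.
rng = np.random.default_rng(0)
n_tokens, d_head, n_calib = 500, 8, 400
heads = []
for _ in range(2):
    base = rng.standard_normal((n_tokens, d_head))
    heads.append(base)
    for _ in range(2):
        heads.append(base @ rng.standard_normal((d_head, d_head))
                     + 0.05 * rng.standard_normal((n_tokens, d_head)))
acts = np.stack(heads)                              # shape (6, 500, 8)
calib, live = acts[:, :n_calib], acts[:, n_calib:]  # step (1): calibration vs. inference tokens

r2 = pairwise_r2(calib)                             # step (2): pairwise probes
refs = greedy_reference_set(r2, threshold=0.8)      # step (3): thresholded cover

maps = {}                                           # step (4): one map per non-reference head
for t in range(acts.shape[0]):
    if t not in refs:
        r = max(refs, key=lambda i: r2[i, t])       # best available reference for head t
        W, *_ = np.linalg.lstsq(calib[r], calib[t], rcond=None)
        maps[t] = (r, W)

cache = {r: live[r] for r in refs}                  # step (5): store reference heads only
rebuilt = {t: cache[r] @ W for t, (r, W) in maps.items()}
rel_err = max(np.linalg.norm(rebuilt[t] - live[t]) / np.linalg.norm(live[t]) for t in maps)
print(f"reference heads: {sorted(refs)} of {acts.shape[0]}")
print(f"worst relative reconstruction error: {rel_err:.3f}")
```

The cached fraction in this sketch is `len(refs) / n_heads`. In a real model that ratio, and the accuracy cost of reconstruction, would depend on how the threshold is chosen.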

Theorems & Definitions (3)

  • Theorem 3.1
  • Proof: Proof Sketch
  • Theorem C.1: Hanson-Wright inequality (Vershynin, 2018)