
Unveiling Simplicities of Attention: Adaptive Long-Context Head Identification

Konstantin Donhauser, Charles Arnal, Mohammad Pezeshki, Vivien Cabannes, David Lopez-Paz, Kartik Ahuja

TL;DR

The paper analyzes long-context attention in decoder-only transformers and identifies two head regimes: local heads that rely on nearby tokens, and long-context heads whose behavior depends on the query. It introduces QAdA, a query-adaptive criterion that uses second-order statistics of the keys, the mean $\mu_K$ and covariance $\Sigma_K$, to predict which heads require long-context processing without computing full attention, enabling efficient sparsification. Across Llama, Qwen, and Mistral models on benchmarks such as RULER, LongBench, and long-context reasoning tasks, QAdA matches or exceeds static pruning and approaches oracle-level performance, while reducing the run-time complexity to $O((1-\rho) T d + \rho T_{\text{local}} d)$. These results reveal simple, robust patterns in attention behavior over long sequences and point to practical efficiency improvements for long-context NLP tasks. The work advances the mechanistic understanding of attention while offering a scalable path toward per-query head allocation and faster inference in large language models.
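To make the complexity expression concrete, the sketch below evaluates it for one hypothetical setting. It assumes $T$ is the context length, $d$ the head dimension, $\rho$ the fraction of heads classified as local, and $T_{\text{local}}$ the local window size. The values used ($T = 16384$, $d = 128$, $\rho = 0.85$, $T_{\text{local}} = 1024$) are illustrative only and are not taken from the paper; constant factors and the overhead of the head-selection criterion itself are not modeled.

```python
def attention_cost(T: int, d: int, rho: float, T_local: int) -> float:
    """Average per-head cost for one query, O((1 - rho) T d + rho T_local d).

    A fraction (1 - rho) of heads attends over all T keys, while the
    remaining rho fraction (local heads) attends only over T_local keys.
    Constant factors are dropped (assumption of this sketch).
    """
    return (1 - rho) * T * d + rho * T_local * d


def dense_cost(T: int, d: int) -> float:
    """Cost of standard dense attention: every head attends over all T keys."""
    return T * d


# Hypothetical setting, chosen for illustration only.
T, d, rho, T_local = 16384, 128, 0.85, 1024
sparse = attention_cost(T, d, rho, T_local)
dense = dense_cost(T, d)
print(f"sparse/dense cost ratio: {sparse / dense:.3f}")  # prints 0.203
```

Under these assumptions the attention cost falls to roughly a fifth of the dense cost; the saving grows as $\rho$ increases or as $T_{\text{local}}$ shrinks relative to $T$.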

Abstract

The ability to process long contexts is crucial for many natural language processing tasks, yet it remains a significant challenge. While substantial progress has been made in enhancing the efficiency of attention mechanisms, there is still a gap in understanding how attention heads function in long-context settings. In this paper, we observe that while certain heads consistently attend to local information only, others swing between attending to local and long-context information depending on the query. This raises the question: can we identify which heads require long-context information to predict the next token accurately? We demonstrate that it's possible to predict which heads are crucial for long-context processing using only local keys. The core idea here is to exploit a simple model for the long-context scores via second moment approximations. These findings unveil simple properties of attention in the context of long sequences, and open the door to potentially significant gains in efficiency.
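The second-moment idea can be illustrated with a small NumPy sketch. It assumes (an assumption of this sketch, not necessarily the paper's exact formulation) that the distant "bulk" keys follow a Gaussian $\mathcal{N}(\mu_K, \Sigma_K)$, that scores are $s = q^\top k / \sqrt{d}$, and that a head is labeled local when the estimated share of softmax mass on the local window is at least a threshold $\tau$. Under this model the bulk scores are Gaussian with mean $m = q^\top \mu_K / \sqrt{d}$ and variance $v = q^\top \Sigma_K q / d$, so each bulk key contributes $\mathbb{E}[e^{s}] = e^{m + v/2}$ in expectation and the bulk never has to be scored key by key. In the paper, $\mu_K$ and $\Sigma_K$ are generated from a prompt (see the ablations in Figure 4); here they are simply estimated from synthetic keys. All function names are hypothetical, and `exact_local_share` exists only to compare against full attention.

```python
import numpy as np


def local_share(q, local_keys, mu_k, sigma_k, n_bulk):
    """Estimated fraction of softmax mass on the local window (hypothetical helper).

    Local scores are computed exactly; the n_bulk distant keys are replaced by
    a Gaussian model N(mu_k, sigma_k), whose expected exp-score is exp(m + v/2).
    """
    d = q.shape[0]
    log_local = np.logaddexp.reduce(local_keys @ q / np.sqrt(d))  # log sum_i exp(s_i)
    m = q @ mu_k / np.sqrt(d)   # mean of the bulk scores
    v = q @ sigma_k @ q / d     # variance of the bulk scores
    log_bulk = np.log(n_bulk) + m + 0.5 * v
    # local / (local + bulk), computed in log space for numerical stability
    return float(np.exp(log_local - np.logaddexp(log_local, log_bulk)))


def is_local_head(q, local_keys, mu_k, sigma_k, n_bulk, tau=0.6):
    """Label the head local if the estimated local share reaches tau."""
    return bool(local_share(q, local_keys, mu_k, sigma_k, n_bulk) >= tau)


def exact_local_share(q, local_keys, bulk_keys):
    """Same share computed from every bulk key (full attention), for comparison."""
    d = q.shape[0]
    log_local = np.logaddexp.reduce(local_keys @ q / np.sqrt(d))
    log_bulk = np.logaddexp.reduce(bulk_keys @ q / np.sqrt(d))
    return float(np.exp(log_local - np.logaddexp(log_local, log_bulk)))


# Synthetic example: Gaussian bulk keys whose statistics are estimated from the keys themselves.
rng = np.random.default_rng(0)
d, n_bulk, n_local = 64, 4000, 128
bulk_keys = rng.normal(size=(n_bulk, d))
mu_k = bulk_keys.mean(axis=0)
sigma_k = np.cov(bulk_keys, rowvar=False)
q = rng.normal(size=d)
local_keys = 0.5 * q + 0.1 * rng.normal(size=(n_local, d))
print(local_share(q, local_keys, mu_k, sigma_k, n_bulk),
      exact_local_share(q, local_keys, bulk_keys))
```

Because the Gaussian estimate needs only $\mu_K$, $\Sigma_K$, and the number of bulk keys, the decision for a head costs the local-window scores plus one quadratic form, independent of how long the bulk is.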

Paper Structure

This paper contains 50 sections, 7 equations, 16 figures, 2 tables.

Figures (16)

  • Figure 1: Attention sparsity and its impact on efficiency. Left: Attention scores are split into the bulk ($A^{\text{bulk}}$) for distant tokens and the local window ($A^{\text{local}}$) for nearby ones. A head is considered local if most of its attention mass falls within the local window. The static criterion pre-assigns local heads, while the adaptive oracle compares bulk and local contributions per query but is computationally expensive. Our approximation models $A^{\text{bulk}}$ with a Gaussian distribution for efficiency. Middle: Oracle-based classification with $\tau = 0.6$ (see Figure 3 for the threshold) reveals three types of heads: consistently local (labeled local more than $95\%$ of the time), often long-context (labeled local less than $50\%$ of the time), and varying, which switch behavior dynamically. Right: Comparison of three methods: static (green) removes a fixed fraction of heads, the adaptive oracle (blue) selects heads dynamically but is costly, and our adaptive method (purple) achieves near-oracle performance at significantly lower cost. As sparsity increases, static pruning degrades performance, while our adaptive method remains robust. These results show that most attention heads do not need to attend to the entire context, enabling significant efficiency gains through query-adaptive head classification.
  • Figure 2: Examples of attention score distributions for each possible classification outcome, taking the oracle criterion as ground truth, with $\tau_{\text{approx}} = \tau_{\text{oracle}} = 0.6$. We show histograms of scores from the local window $\mathcal{I}$ (brown) and the bulk complement $[T] \setminus \mathcal{I}$ (gray), along with the Gaussian approximation of the bulk (black dashed line). The annotations above each plot give the values of the statistics used by the oracle criterion and by the adaptive criterion.
  • Figure 3: Comparison of QAdA against the adaptive and static oracles on the RULER benchmark. Left: For Llama 3-8B, the average performance (top) over the selected RULER 16k tasks as a function of the average sparsity for varying thresholds $\tau$, along with the worst-case performance drop (%) relative to the baseline among the selected tasks. Middle and right: Average performance and worst-case drop at a fixed sparsity level of 0.85 across three model families (Llama, Mistral, and Qwen) on RULER 8k (middle) and RULER 16k (right). Our adaptive criterion consistently matches or outperforms the static oracle criterion and, in some cases (e.g., Mistral), even achieves performance comparable to the adaptive oracle.
  • Figure 4: Top row: Similar to Figure 3, we show the average performance on the LongBench benchmark, the pass@1 score on the MBPP task, and the F1 score on the GSM8k task. Bottom row: Ablations on the content of the prompt (e, f) and the length of the prompt (g) used to generate the mean $\mu_K$ and covariance $\Sigma_K$ for the adaptive criterion described in the method section. We show the normalized performance as a function of sparsity for (e) the "vt" task, (f) the "fwe" task, and (g) averaged over the RULER $8$k tasks.
  • Figure 5: a) Mean and standard deviation of the fraction of heads labeled as local as a function of the time step, for prompts from the "fwe" task. b) Average sparsity and its standard deviation as a function of the threshold $\tau$ for Llama 3-8B over the RULER 8k and 16k tasks and the LongBench tasks. The annotations show the mean and standard deviation of the normalized performance (with $1$ corresponding to standard dense attention). c) Average sparsity as a function of the threshold $\tau$, as in b), shown per task for the QAdA criterion. We also show the average sparsity for a context-independent task, which does not require the context to be solved; for the same threshold, QAdA labels significantly more heads as local on this task.
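The oracle criterion and the head taxonomy described in the Figure 1 caption can be sketched as follows. This is a minimal illustration under stated assumptions: the local window is taken to be the last `window` positions of the context (the paper's exact window definition may differ), a head is labeled local when its softmax mass on that window is at least $\tau$, and the $95\%$ and $50\%$ cut-offs are the ones quoted in the Figure 1 caption. Sparsity is measured here as the fraction of (step, head) pairs labeled local. All function names are hypothetical.

```python
import numpy as np


def oracle_is_local(scores, window, tau=0.6):
    """Oracle labels for one query.

    scores: (n_heads, T) pre-softmax attention scores over the full context.
    Returns a boolean array that is True where the softmax mass on the last
    `window` positions (assumed local window) is at least tau.
    """
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    probs /= probs.sum(axis=-1, keepdims=True)
    local_mass = probs[:, -window:].sum(axis=-1)
    return local_mass >= tau


def head_types(local_labels):
    """Bucket heads by how often they are labeled local.

    local_labels: (n_steps, n_heads) boolean oracle labels over time steps.
    """
    freq = local_labels.mean(axis=0)
    return np.where(freq > 0.95, "consistently local",
                    np.where(freq < 0.5, "often long-context", "varying"))


def sparsity(local_labels):
    """Fraction of (step, head) pairs labeled local, i.e. restricted to the window."""
    return float(local_labels.mean())


# Two heads, a context of 8 tokens, and a local window of 2.
scores = np.array([[0, 0, 0, 0, 0, 0, 10, 10],   # mass concentrated on the window
                   [0, 0, 0, 0, 0, 0, 0, 0]],     # uniform attention
                  dtype=float)
print(oracle_is_local(scores, window=2))  # [ True False]
```

Heads that are labeled local in nearly every step are natural candidates for static pruning, whereas the "varying" heads are the ones that motivate a query-adaptive criterion.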