
Critical attention scaling in long-context transformers

Shi Chen, Zhengjiang Lin, Yury Polyanskiy, Philippe Rigollet

TL;DR

This paper addresses rank-collapse in long-context transformers by analyzing a tractable self-attention model with attention scores $a_{ij}=\beta_n\langle y_i,y_j\rangle$ and a pre-layer-norm setup, proving that the critical scaling is $\beta_n\asymp\log n$ (parametrized as $\beta_n=\gamma\log n$). It characterizes the forward dynamics through a phase transition with subcritical, critical, and supercritical regimes around $\gamma=1/(1-\rho)$, where $\rho$ is the common pairwise inner product of the tokens (with analogous thresholds in the relaxed setting), showing when attention contractively aggregates tokens, sparsely concentrates interactions, or acts nearly as the identity. The backward pass exhibits a matching transition in gradient propagation, with vanishing gradients in the subcritical regime and nontrivial gradients in the supercritical regime, supported by explicit asymptotics. Numerical experiments on almost-simplex token configurations validate the theoretical predictions, revealing sharp phase transitions in large dimensions and a middle phase in intermediate regimes. Together, these results provide a rigorous justification for logarithmic attention scaling and its role in enabling content-adaptive, sparse interactions while preserving trainability at long context lengths.
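The forward-pass trichotomy can be checked numerically in a stripped-down setting. The sketch below is not the authors' code. It assumes exactly equiangular unit tokens (one shared direction plus a private orthogonal direction per token), takes the value map to be the identity, and omits the pre-layer norm, residual connection, and MLP. The helper names `equiangular_tokens`, `attention`, and `summarize` are hypothetical. In this idealized setting the threshold works out to $\gamma=1/(1-\rho)$, which is $\gamma=2$ for $\rho=0.5$.

```python
import numpy as np

def equiangular_tokens(n: int, rho: float) -> np.ndarray:
    """n unit vectors in R^(n+1) with <y_i, y_j> = rho for all i != j (0 <= rho < 1)."""
    Y = np.zeros((n, n + 1))
    Y[:, :n] = np.sqrt(1.0 - rho) * np.eye(n)  # private direction for each token
    Y[:, n] = np.sqrt(rho)                      # direction shared by all tokens
    return Y

def attention(Y: np.ndarray, beta: float) -> np.ndarray:
    """Row-stochastic attention matrix softmax_j(beta * <y_i, y_j>)."""
    scores = beta * (Y @ Y.T)
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    W = np.exp(scores)
    return W / W.sum(axis=1, keepdims=True)

def summarize(n: int, rho: float, gamma: float) -> tuple[float, float]:
    """Return (mean self-attention weight, cosine between two output tokens)."""
    Y = equiangular_tokens(n, rho)
    A = attention(Y, gamma * np.log(n))          # beta_n = gamma * log n
    Z = A @ Y                                    # value map taken as the identity
    Z /= np.linalg.norm(Z, axis=1, keepdims=True)
    self_weight = float(np.mean(np.diag(A)))
    out_cos = float(Z[0] @ Z[1])                 # all pairs agree by symmetry
    return self_weight, out_cos

if __name__ == "__main__":
    rho = 0.5                                    # idealized threshold gamma = 1/(1 - rho) = 2
    for gamma in (1.0, 2.0, 3.0):
        for n in (100, 300, 1000):
            w, c = summarize(n, rho, gamma)
            print(f"gamma={gamma}  n={n:5d}  self-weight={w:.3f}  output cos={c:.3f}")
```

As $n$ grows, $\gamma=1$ should push the self-weight toward 0 and the output cosine toward 1 (collapse to one direction), $\gamma=3$ should push the self-weight toward 1 and leave the output cosine near $\rho$ (near-identity), and $\gamma=2$ should keep the self-weight near 1/2.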

Abstract

As large language models scale to longer contexts, attention layers suffer from a fundamental pathology: attention scores collapse toward uniformity as context length $n$ increases, causing tokens to cluster excessively, a phenomenon known as rank-collapse. While $\textit{attention scaling}$ effectively addresses this deficiency by rescaling attention scores with a polylogarithmic factor $β_n$, theoretical justification for this approach remains lacking. We analyze a simplified yet tractable model that magnifies the effect of attention scaling. In this model, attention exhibits a phase transition governed by the scaling factor $β_n$: insufficient scaling collapses all tokens to a single direction, while excessive scaling reduces attention to identity, thereby eliminating meaningful interactions between tokens. Our main result identifies the critical scaling $β_n \asymp \log n$ and provides a rigorous justification for attention scaling in YaRN and Qwen, clarifying why logarithmic scaling maintains sparse, content-adaptive attention at large context lengths.
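For intuition about why the critical scale is logarithmic, consider an idealized case assumed here for illustration: $n$ unit tokens with common pairwise inner product $\rho$, and no pre-layer norm. Each softmax row then has one entry $e^{\beta}$ and $n-1$ entries $e^{\beta\rho}$, so the self-weight is $1/(1+(n-1)e^{-\beta(1-\rho)})$. Any fixed $\beta$ sends it to 0 as $n\to\infty$ (uniform attention), while the scale that balances a token against all others is $\beta^*=\log(n-1)/(1-\rho)$, which grows like $\log n$. The function names below are hypothetical.

```python
import math

def self_weight(n: int, beta: float, rho: float) -> float:
    """Diagonal softmax weight when all n unit tokens have pairwise inner product rho.

    Row i of softmax_j(beta * <y_i, y_j>) has one entry exp(beta) (j = i) and
    n - 1 entries exp(beta * rho) (j != i); dividing through by exp(beta) gives this form.
    """
    return 1.0 / (1.0 + (n - 1) * math.exp(-beta * (1.0 - rho)))

def balancing_beta(n: int, rho: float) -> float:
    """Scale at which a token keeps exactly half of its attention on itself (n >= 3)."""
    return math.log(n - 1) / (1.0 - rho)

if __name__ == "__main__":
    rho = 0.5  # beta*/log n should approach 1/(1 - rho) = 2
    for n in (10**2, 10**4, 10**6):
        b_star = balancing_beta(n, rho)
        print(f"n={n:>8}  self-weight at fixed beta=4: {self_weight(n, 4.0, rho):.2e}  "
              f"beta*={b_star:.2f}  beta*/log n={b_star / math.log(n):.3f}")
```

The fixed-$\beta$ column shows the drift toward uniform attention, and the ratio $\beta^*/\log n$ tends to $1/(1-\rho)$, the same threshold that appears in the TL;DR.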


Paper Structure

This paper contains 11 sections, 20 theorems, 119 equations, 2 figures, 1 table.

Key Result

Theorem 2.1

Under the simplex assumption (Assumption \ref{a:simplex}), there is a $\rho' \in (0,1)$ such that $\langle y_i', y_j'\rangle=\rho'$ for all $i\neq j$. Moreover, if $\beta = \gamma \log n$ where $\gamma$ is a positive constant, then for any $i\neq j$, it holds

Figures (2)

  • Figure 1: Plots of the input-to-output angle ratio $\lambda$, defined in \ref{eqn:input_output_ratio}, as a function of $\rho$ and $\gamma$. The tokens are first normalized by a pre-layer normalization and then passed through a single self-attention layer \ref{e:att}, with residual connections and MLP layers omitted. The dashed curve corresponds to $\gamma=\tfrac{1}{1-\rho}$, which approximates the actual phase transition with increasing accuracy as $d$ grows, as implied by Theorem \ref{thm:att limit phase 2}.
  • Figure 2: Plots of the normalized norm $\eta$ of the gradient, defined in \ref{eqn:normalized_norm}, as a function of $\rho$ and $\gamma$. The tokens are first normalized by a pre-layer normalization and then passed through a single self-attention layer \ref{e:att}, with residual connections and MLP layers omitted. The dashed curve corresponds to $\gamma=\tfrac{1}{1-\rho}$, which approximates the actual phase transition with increasing accuracy as $d$ grows, as implied by Theorem \ref{thm:att propagation gradient norm 2}. The norm $\eta$ is estimated with the Hutchinson trace estimator (Hutchinson, 1989), based on the definition in \ref{e:def jacobian norm}.
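Hutchinson's estimator rests on the identity $\|J\|_F^2=\operatorname{tr}(J^\top J)=\mathbb{E}\,\|Jv\|^2$ for probes $v$ with i.i.d. $\pm1$ entries, so a Jacobian norm can be estimated from Jacobian-vector products alone. The sketch below illustrates only that identity and does not reproduce the normalization behind $\eta$ in \ref{e:def jacobian norm}. It assumes a toy attention layer $\mathrm{softmax}(\beta YY^\top)Y$ on random unit tokens (value map equal to the identity, no residual or MLP) and approximates $Jv$ by central finite differences; all function names are hypothetical.

```python
import numpy as np

def attention_layer(Y: np.ndarray, beta: float) -> np.ndarray:
    """Toy layer softmax(beta * Y Y^T) Y: value map is the identity, no residual/MLP."""
    S = beta * (Y @ Y.T)
    S = S - S.max(axis=1, keepdims=True)        # stabilize the softmax
    A = np.exp(S)
    A /= A.sum(axis=1, keepdims=True)
    return A @ Y

def make_jvp(Y: np.ndarray, beta: float, eps: float = 1e-6):
    """Return v -> J v for the flattened layer, via central finite differences."""
    n, d = Y.shape
    def jvp(v: np.ndarray) -> np.ndarray:
        V = v.reshape(n, d)
        plus = attention_layer(Y + eps * V, beta)
        minus = attention_layer(Y - eps * V, beta)
        return ((plus - minus) / (2.0 * eps)).ravel()
    return jvp

def hutchinson_sq_frobenius(matvec, dim: int, num_probes: int, rng) -> float:
    """Estimate ||J||_F^2 = E ||J v||^2 with Rademacher probes v (entries +1 or -1)."""
    total = 0.0
    for _ in range(num_probes):
        v = rng.choice([-1.0, 1.0], size=dim)
        Jv = matvec(v)
        total += float(Jv @ Jv)
    return total / num_probes

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d, gamma = 8, 4, 2.0
    Y = rng.standard_normal((n, d))
    Y /= np.linalg.norm(Y, axis=1, keepdims=True)   # unit-norm tokens
    jvp = make_jvp(Y, gamma * np.log(n))
    # Exact value for comparison: build J one column at a time (cheap only for small n*d).
    J = np.column_stack([jvp(e) for e in np.eye(n * d)])
    exact = float(np.sum(J ** 2))
    estimate = hutchinson_sq_frobenius(jvp, n * d, num_probes=4000, rng=rng)
    print(f"exact ||J||_F^2 = {exact:.4f}   Hutchinson estimate = {estimate:.4f}")
```

In practice an autodiff Jacobian-vector product would replace the finite differences; the column-by-column exact value is affordable here only because $nd=32$.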

Theorems & Definitions (40)

  • Theorem 2.1
  • proof: Proof of \ref{e:simplex_phase_alpha=0}
  • Theorem 2.2
  • Theorem 2.3
  • Theorem 2.4
  • Lemma A.1
  • proof: Proof of Lemma \ref{lem:Z_i asymptotics}
  • Lemma A.2
  • proof: Proof of Lemma \ref{lem:gamma large single point}
  • Lemma A.3
  • ...and 30 more