Critical attention scaling in long-context transformers
Shi Chen, Zhengjiang Lin, Yury Polyanskiy, Philippe Rigollet
TL;DR
This paper addresses rank-collapse in long-context transformers by analyzing a tractable self-attention model with attention logits $a_{ij}=\beta_n\langle y_i,y_j\rangle$ in a pre-layer-norm setup, and proves that the critical scaling is $\beta_n\asymp\log n$, parametrized as $\beta_n=\gamma\log n$ for a constant $\gamma$. It characterizes the forward dynamics through a phase transition with subcritical, critical, and supercritical regimes around $\gamma=1/(1-\rho)$ (with analogous thresholds in the relaxed setting). The analysis shows when attention contractively aggregates tokens, when it concentrates interactions sparsely, and when it acts nearly as the identity. The backward pass exhibits a matching transition in gradient propagation, with vanishing gradients in the subcritical regime and nontrivial gradients in the supercritical regime, supported by explicit asymptotics. Numerical experiments on almost-simplex token configurations validate the theoretical predictions, revealing sharp phase transitions in large dimensions and a middle phase in intermediate regimes. Together, these results provide a rigorous justification for logarithmic attention scaling and its role in enabling content-adaptive, sparse interactions while preserving trainability at long context lengths.
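A back-of-the-envelope calculation suggests where the threshold $\gamma=1/(1-\rho)$ comes from. This is a heuristic sketch, not the paper's proof. It assumes row-wise softmax attention and an idealized equiangular configuration of unit tokens with $\langle y_i,y_i\rangle=1$ and $\langle y_i,y_j\rangle=\rho<1$ for $i\neq j$. The notation $P_{ii}$ for the weight that token $i$ places on itself is introduced here for illustration only:

$$
P_{ii} = \frac{e^{\beta_n}}{e^{\beta_n} + (n-1)\,e^{\beta_n\rho}} = \frac{1}{1 + (n-1)\,e^{-\beta_n(1-\rho)}}.
$$

With $\beta_n=\gamma\log n$, the correction term is $(n-1)\,e^{-\beta_n(1-\rho)} = (n-1)\,n^{-\gamma(1-\rho)}$, which behaves like $n^{1-\gamma(1-\rho)}$. Therefore

$$
\lim_{n\to\infty} P_{ii} =
\begin{cases}
0, & \gamma < 1/(1-\rho),\\
1/2, & \gamma = 1/(1-\rho),\\
1, & \gamma > 1/(1-\rho).
\end{cases}
$$

In this toy setting, subcritical scaling spreads each token's attention over all the others, which drives aggregation. Supercritical scaling makes attention nearly the identity. Only the critical scaling keeps a nondegenerate split between a token and the rest of the context.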
Abstract
As large language models scale to longer contexts, attention layers suffer from a fundamental pathology: attention scores collapse toward uniformity as the context length $n$ increases, causing tokens to cluster excessively, a phenomenon known as rank-collapse. Attention scaling addresses this deficiency effectively by rescaling attention scores with a polylogarithmic factor $\beta_n$, but a theoretical justification for this approach has been lacking. We analyze a simplified yet tractable model that magnifies the effect of attention scaling. In this model, attention exhibits a phase transition governed by the scaling factor $\beta_n$: insufficient scaling collapses all tokens to a single direction, while excessive scaling reduces attention to the identity, thereby eliminating meaningful interactions between tokens. Our main result identifies the critical scaling $\beta_n \asymp \log n$. It provides a rigorous justification for the attention scaling used in YaRN and Qwen and clarifies why logarithmic scaling maintains sparse, content-adaptive attention at large context lengths.
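A small numerical sketch can make the uniformity pathology and the effect of logarithmic scaling concrete. It assumes row-wise softmax attention over logits $\beta_n\langle y_i,y_j\rangle$ and an idealized equiangular configuration of unit tokens with pairwise inner product $\rho<1$. This configuration is a stand-in, not the paper's exact setup. The helpers `self_weight` and `scaled_beta` are hypothetical names for illustration. The sketch compares a fixed $\beta$ against $\beta_n=\gamma\log n$ for $\gamma$ below, at, and above $1/(1-\rho)$.

```python
import math


def scaled_beta(n: int, gamma: float) -> float:
    """Hypothetical helper: logarithmic attention scaling beta_n = gamma * log(n)."""
    return gamma * math.log(n)


def self_weight(n: int, beta: float, rho: float) -> float:
    """Softmax weight that token i places on itself (toy model, an assumption).

    Assumes n unit-norm tokens with <y_i, y_i> = 1 and <y_i, y_j> = rho for
    i != j, and logits a_ij = beta * <y_i, y_j>. Dividing the numerator and
    denominator by exp(beta) avoids overflow for large beta.
    """
    if not rho < 1.0:
        raise ValueError("need rho < 1 so that tokens are distinct")
    return 1.0 / (1.0 + (n - 1) * math.exp(-beta * (1.0 - rho)))


if __name__ == "__main__":
    rho = 0.5
    critical = 1.0 / (1.0 - rho)  # threshold gamma = 1 / (1 - rho), equal to 2.0 here
    for n in (10**2, 10**4, 10**6):
        # Constant scaling: the weight decays roughly like 1/n, i.e. attention becomes uniform.
        fixed = self_weight(n, beta=1.0, rho=rho)
        # Logarithmic scaling at half, exactly, and twice the critical gamma.
        sub, crit, sup = (
            self_weight(n, scaled_beta(n, g * critical), rho) for g in (0.5, 1.0, 2.0)
        )
        print(f"n={n:>7}  fixed beta: {fixed:.2e}  "
              f"sub/crit/super: {sub:.3f} {crit:.3f} {sup:.3f}")
```

As $n$ grows, the fixed-$\beta$ column shrinks toward zero, the subcritical column tends to 0, the critical column stays near 1/2, and the supercritical column approaches 1. This matches the collapse-versus-identity dichotomy described above, though only within this idealized configuration.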
