
Single-Head Attention in High Dimensions: A Theory of Generalization, Weights Spectra, and Scaling Laws

Fabrizio Boncoraglio, Vittorio Erba, Emanuele Troiani, Yizhou Xu, Florent Krzakala, Lenka Zdeborová

TL;DR

The paper develops a high‑dimensional theory of single‑head tied attention trained by empirical risk minimization (ERM), linking learned weight spectra to generalization. It maps ERM in attention to a generalized matrix sensing problem and solves it with approximate message passing to obtain exact asymptotics for training/test errors, interpolation/recovery thresholds, and the spectrum of the learned query–key map. The spectrum comprises a structured bulk plus spectral outliers that encode learned features, providing a direct quantitative bridge between spectral structure and generalization. The work also uncovers power‑law scaling laws for targets with heavy‑tailed spectra, showing sequential spectral recovery and universal exponents, thereby offering a principled explanation for emergence and scaling phenomena observed in transformers. Although based on simplifying assumptions, the results reproduce key qualitative phenomena and pave the way for extensions to more realistic data distributions and architectures.
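As a minimal sketch of why a tied single-head layer can be viewed as matrix sensing, assume a sequence $X$ of $T$ tokens in $\mathbb{R}^d$ and a shared query/key weight $W$ of width $r$, so that the query–key map is $S = WW^\top/d$. The $1/d$ normalization and the names `tied_attention` and `sensing_matrix` are hypothetical choices for illustration, not taken from the paper. Every pre-softmax score $x_a^\top S x_b$ equals the Frobenius inner product $\langle S, x_a x_b^\top \rangle$, i.e. a linear measurement of $S$, which the softmax then passes through a nonlinearity.

```python
import numpy as np

def tied_attention(X, W, beta=1.0):
    """Single-head tied attention: queries and keys share the weight W.

    X: (T, d) sequence of T tokens; W: (d, r) shared query/key weights.
    Returns the (T, T) row-stochastic attention matrix.
    The 1/d normalization of S is a hypothetical choice for this sketch.
    """
    d = X.shape[1]
    S = W @ W.T / d                               # symmetric query-key map, rank <= r
    scores = beta * X @ S @ X.T                   # scores[a, b] = x_a^T S x_b
    scores -= scores.max(axis=1, keepdims=True)   # stabilize the softmax
    A = np.exp(scores)
    return A / A.sum(axis=1, keepdims=True)

def sensing_matrix(X, a, b):
    """Rank-one 'sensing' matrix with x_a^T S x_b = <S, x_a x_b^T>_F."""
    return np.outer(X[a], X[b])

rng = np.random.default_rng(0)
T, d, r = 2, 50, 25
X = rng.standard_normal((T, d))
W = rng.standard_normal((d, r))
S = W @ W.T / d

A = tied_attention(X, W)
# Each pre-softmax score is a linear measurement of S:
score_01 = X[0] @ S @ X[1]
measured = np.sum(S * sensing_matrix(X, 0, 1))    # Frobenius inner product
print(A.sum(axis=1), np.isclose(score_01, measured))
```

Because $W$ is shared between queries and keys, $S$ is symmetric and positive semidefinite, so the score matrix $XSX^\top$ is symmetric as well.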

Abstract

Trained attention layers exhibit striking and reproducible spectral structure of the weights, including low-rank collapse, bulk deformation, and isolated spectral outliers, yet the origin of these phenomena and their implications for generalization remain poorly understood. We study empirical risk minimization in a single-head tied-attention layer trained on synthetic high-dimensional sequence tasks generated from the attention-indexed model. Using tools from random matrix theory, spin-glass theory, and approximate message passing, we obtain an exact high-dimensional characterization of training and test error, interpolation and recovery thresholds, and the spectrum of the key and query matrices. Our theory predicts the full singular-value distribution of the trained query-key map, including low-rank structure and isolated spectral outliers, in qualitative agreement with observations in more realistic transformers. Finally, for targets with power-law spectra, we show that learning proceeds through sequential spectral recovery, leading to the emergence of power-law scaling laws.
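The "bulk plus isolated outliers" picture can be illustrated with a classical random-matrix example, the BBP transition for a spiked GOE matrix. This is a generic textbook illustration, not the spectral law of Claim 4.1. Assuming a semicircle bulk on $[-2, 2]$ and a rank-one spike of strength $\theta > 1$, the top eigenvalue detaches from the bulk at $\theta + 1/\theta$, and its eigenvector has squared overlap $1 - 1/\theta^2$ with the planted direction, so the outlier "encodes" the planted feature.

```python
import numpy as np

def spiked_goe(d, theta, rng):
    """GOE bulk (semicircle on [-2, 2]) plus a rank-one spike theta * v v^T."""
    G = rng.standard_normal((d, d))
    bulk = (G + G.T) / np.sqrt(2 * d)     # off-diagonal entries have variance 1/d
    v = rng.standard_normal(d)
    v /= np.linalg.norm(v)                # unit-norm planted "feature" direction
    return bulk + theta * np.outer(v, v), v

rng = np.random.default_rng(1)
d, theta = 1000, 3.0
M, v = spiked_goe(d, theta, rng)
eigvals, eigvecs = np.linalg.eigh(M)      # eigenvalues in ascending order

top, second = eigvals[-1], eigvals[-2]
overlap = abs(eigvecs[:, -1] @ v) ** 2    # alignment of the outlier with v
print(f"outlier {top:.3f} (theory {theta + 1 / theta:.3f}), "
      f"bulk edge {second:.3f}, "
      f"overlap {overlap:.3f} (theory {1 - 1 / theta**2:.3f})")
```

For $\theta \le 1$ no outlier appears and the top eigenvector carries vanishing information about $v$ as $d$ grows.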

Paper Structure

This paper contains 58 sections, 1 theorem, 169 equations, 11 figures.

Key Result

Corollary 3.2

Consider the setting of Section sec:set_contr with $\lambda \to 0^+$, $\Delta \geq 0$ and $T \geq 2$. Then there exists a sample ratio $\alpha_{\rm interp}$ such that the training loss at its global minimum is zero for $\alpha < \alpha_{\rm interp}$ (perfect fit of the training set), and […], where $\bar{\delta}$ is the solution of […] and $J$ is defined in Eq. eq:MJ. Moreover, if $\Delta = 0$, […]

Figures (11)

  • Figure 1: (Left) Test error of the ERM estimator Eq. eq:erm (Claim claim:main, Eq. eq:text_train) compared with Adam simulations at $d=100$ ($64$ instances, error bars = standard deviation) as a function of the sample ratio $\alpha=n/d^2$, for $\kappa=0.75, 1$ (model width) and parameters $\lambda=0.01$, $\Delta=0.5$, $T=2$, $\beta=\beta_0=1$, with the MP target (Section sec:learning-curves, $\kappa_0=0.5$). Theory and simulations agree. (Right) Singular value spectrum of the trained weights from theory (blue, Claim claim:spectrum, Eq. eq:spectral_law_maintext) vs. Adam simulations (grey histograms) at $\alpha=0.05, 0.5, 4, 40$ ($d=200$, $64$ runs, $2000$ samples in the test set). The asymptotic spectrum of the target is shown as a dashed red line. The theory also captures the delta peak at zero. For large $\alpha$, the spectrum splits into two bulks.
  • Figure 2: Test error as a function of $\alpha$ for standard factorized query–key training, Eq. eq:erm, and direct (non-factorized) training of the attention matrix, Eq. eq:ermfrob, both at optimal regularization (selected by cross-validation). The factorized parameterization achieves significantly lower test error for all values of $\alpha$ shown. Solid lines show theoretical predictions (Claim claim:main and Appendix app:L2), while dots correspond to numerical experiments with Adam at $d=100$, averaged over $8$ runs with $2000$ test samples. Parameters are $\kappa_0 = 0.05$, $\Delta = 0.05$, $T=2$, and $\beta=\beta_0=1$.
  • Figure 3: Qualitative representation of the spectral density and error decomposition of Section sec:spectrum_generalization. (Top) Spectrum of the target (below the horizontal axis) and of the ERM estimator (above the axis). The ERM spectrum is composed of outliers (learned target features) and a noise bulk (magenta). (Bottom) Error decomposition for a power-law target with $\gamma = 0.75$, $d=200$, $\Delta=0.5$, $T=2$, $\lambda = 1/d$ (see Appendix app:more), expressed as a fraction of the total error (shown in Figure fig:figurePL).
  • Figure 4: (Left) Excess test error (Eq. eq:test minus its value at $n \to +\infty$) of the ERM estimator Eq. eq:erm (Claim claim:main, Eq. eq:text_train) compared with Adam simulations at $d=200$ ($8$ instances, error bars = standard deviation) as a function of the number of samples $n$, for a power-law target (Section sec:powerlaw, decay exponent $\gamma = 0.75$) with $\Delta=0.5$, $T=2$, $\beta=\beta_0=1$. We plot three values of the regularization, $\lambda=1/d, 1, \sqrt{d}$. Several scaling regimes with different decay exponents are clearly visible. Dashed lines highlight the regimes in which the error decays as $n^{-f(\gamma)}$ with a non-trivial dependence on $\gamma$, namely $f(\gamma) = 1-1/(2\gamma)$ for $\lambda = 1/d$ and $f(\gamma) = 2 - 1/\gamma$ for $\lambda=1,\sqrt{d}$. (Right) Eigenvalue spectrum of the ERM estimator from theory (blue, Claim claim:spectrum) vs. Adam simulations (grey histograms) at $n = 4 \cdot 10^3, 4 \cdot 10^5$ ($d=400$, single runs). For large $n$ the spectrum develops a heavy tail.
  • Figure 5: Excess risk rates as a function of $n$ and $\lambda(n,d)$, with a sketch of the corresponding spectral properties of the learned weights.
  • ...and 6 more figures
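Figure 1 uses an "MP target" with $\kappa_0 = 0.5$. Assuming, hypothetically, that this target is $S_0 = W_0 W_0^\top / d$ with $W_0 \in \mathbb{R}^{d \times \kappa_0 d}$ having i.i.d. Gaussian entries, the sketch below shows its two spectral ingredients: a delta peak at zero holding a fraction $1-\kappa_0$ of the eigenvalues, and a Marchenko–Pastur bulk on $[(1-\sqrt{\kappa_0})^2, (1+\sqrt{\kappa_0})^2]$ for the remaining ones. The helper name `mp_target_eigenvalues` is illustrative.

```python
import numpy as np

def mp_target_eigenvalues(d, kappa0, rng):
    """Eigenvalues of a hypothetical target S0 = W0 W0^T / d, W0 of shape (d, kappa0*d)."""
    r = int(kappa0 * d)
    W0 = rng.standard_normal((d, r))              # i.i.d. N(0, 1) entries
    return np.linalg.eigvalsh(W0 @ W0.T / d), r

rng = np.random.default_rng(2)
d, kappa0 = 200, 0.5
eigvals, r = mp_target_eigenvalues(d, kappa0, rng)

is_zero = np.abs(eigvals) < 1e-8                  # numerically zero eigenvalues
n_zero = int(is_zero.sum())                       # delta peak: expect d - r of them
nonzero = eigvals[~is_zero]
lo, hi = (1 - np.sqrt(kappa0)) ** 2, (1 + np.sqrt(kappa0)) ** 2  # MP support edges
print(n_zero, d - r, nonzero.min(), nonzero.max(), (lo, hi))
```

Under this normalization the nonzero eigenvalues have mean close to $1$, and their histogram approaches the Marchenko–Pastur density with ratio $\kappa_0$ as $d$ grows.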
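A standard identity offers one possible intuition for the gap in Figure 2. This is an assumption on our part, not the paper's stated explanation. For a tied factorization $S = WW^\top$, the squared Frobenius norm of $W$ equals $\operatorname{tr} S$, which is the nuclear norm of the positive semidefinite matrix $S$. An L2 penalty on $W$ therefore acts as a low-rank-promoting penalty on $S$. By contrast, a Frobenius penalty on $S$ itself, which we assume for direct training, does not promote low rank in the same way.

```python
import numpy as np

rng = np.random.default_rng(3)
d, r = 100, 75                         # width ratio kappa = r / d = 0.75
W = rng.standard_normal((d, r))
S = W @ W.T                            # tied (symmetric PSD) query-key map

frob_W_sq = np.sum(W ** 2)                                # L2 penalty on the factor W
nuclear_S = np.sum(np.linalg.svd(S, compute_uv=False))    # nuclear norm of S
frob_S = np.linalg.norm(S, "fro")                         # Frobenius norm of S itself
print(np.isclose(frob_W_sq, nuclear_S), frob_S < nuclear_S)
```

The identity holds for any $W$, since $\|W\|_F^2 = \operatorname{tr}(WW^\top)$ and the singular values of a PSD matrix are its eigenvalues.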
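The exponents quoted for Figure 4 can be evaluated directly, and a measured learning curve can be compared against them with a least-squares fit in log-log coordinates. In this sketch the regime labels, helper names, and synthetic curve are hypothetical. Only the two formulas for $f(\gamma)$ come from the caption.

```python
import numpy as np

def predicted_exponent(gamma, regime):
    """Exponent f(gamma) in excess error ~ n^{-f(gamma)} for the Figure 4 regimes."""
    if regime == "small_lambda":      # lambda = 1/d
        return 1 - 1 / (2 * gamma)
    if regime == "large_lambda":      # lambda = 1 or sqrt(d)
        return 2 - 1 / gamma
    raise ValueError(f"unknown regime: {regime}")

def fit_loglog_slope(n, err):
    """Least-squares slope of log(err) vs log(n), returned as a decay exponent."""
    slope, _ = np.polyfit(np.log(n), np.log(err), deg=1)
    return -slope

gamma = 0.75
f_small = predicted_exponent(gamma, "small_lambda")   # 1/3 up to rounding
f_large = predicted_exponent(gamma, "large_lambda")   # 2/3 up to rounding

# Synthetic check: an exact power law n^{-f} with small multiplicative noise.
rng = np.random.default_rng(4)
n = np.logspace(3, 6, 20)
err = 5.0 * n ** (-f_large) * np.exp(0.02 * rng.standard_normal(n.size))
print(f_small, f_large, fit_loglog_slope(n, err))
```

For $\gamma = 0.75$ this gives $f = 1/3$ when $\lambda = 1/d$ and $f = 2/3$ when $\lambda = 1, \sqrt{d}$. On real learning curves the fit should be restricted to the $n$ range of a single regime, since Figure 4 shows several regimes with different slopes.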

Theorems & Definitions (4)

  • Claim 3.1: Train and test error of ERM
  • Corollary 3.2: Interpolation and perfect recovery thresholds
  • Claim 4.1: Spectra of ERM
  • Conjecture E.1