Specialization of softmax attention heads: insights from the high-dimensional single-location model

M. Sagitova, O. Duranthon, L. Zdeborová

TL;DR

This work analyzes the training dynamics of multi-head softmax attention under SGD, revealing an initial unspecialized phase followed by a multi-stage specialization phase in which different heads sequentially align with latent signal directions.

Abstract

Multi-head attention enables transformer models to represent multiple attention patterns simultaneously. Empirically, head specialization emerges in distinct stages during training, while many heads remain redundant and learn similar representations. We propose a theoretical model capturing this phenomenon, based on the multi-index and single-location regression frameworks. In the first part, we analyze the training dynamics of multi-head softmax attention under SGD, revealing an initial unspecialized phase followed by a multi-stage specialization phase in which different heads sequentially align with latent signal directions. In the second part, we study the impact of attention activation functions on performance. We show that softmax-1 significantly reduces noise from irrelevant heads. Finally, we introduce the Bayes-softmax attention, which achieves optimal prediction performance in this setting.
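To make the claim about softmax-1 concrete, here is a minimal sketch in Python. It assumes that "softmax-1" denotes the common variant with an extra 1 in the denominator, $\exp(s_i)/(1+\sum_j \exp(s_j))$; the paper's exact definition is not reproduced in this summary, and the function names are illustrative only. Under that assumption, a head whose scores are all strongly negative can assign almost no total weight, whereas standard softmax always distributes a total weight of 1.

```python
import math

def softmax(scores):
    """Standard softmax: weights are positive and always sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax1(scores):
    """Assumed softmax-1: exp(s_i) / (1 + sum_j exp(s_j)).

    The extra 1 behaves like an additional score fixed at 0, so the
    weights can sum to less than 1.
    """
    m = max(max(scores), 0.0)  # the shift must also cover the implicit 0 score
    exps = [math.exp(s - m) for s in scores]
    total = math.exp(-m) + sum(exps)  # exp(-m) is the shifted "+1" term
    return [e / total for e in exps]

# A head that finds nothing relevant: all scores strongly negative.
irrelevant_scores = [-5.0, -5.0, -5.0, -5.0]
print("softmax total weight:  ", sum(softmax(irrelevant_scores)))
print("softmax-1 total weight:", sum(softmax1(irrelevant_scores)))
```

With standard softmax the irrelevant head still outputs a full-weight average of its inputs, while under the assumed softmax-1 its total weight is close to zero, which illustrates one way such a head could contribute less noise.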

Paper Structure

This paper contains 27 sections, 15 theorems, 54 equations, 20 figures.

Key Result

Proposition 3.1

The loss of the attention $(k,b,v)\mapsto \mathcal{E}_\sigma(k,b,v)$ can be reparametrized as a function $(m,r,b,v)\in(\mathbb R^{H\times F},\mathcal{S}_+^H,\mathbb R^H,\mathbb R)\mapsto \tilde{\mathcal{E}}_\sigma(m,r,b,v)$ of the following order parameters, for $h,h'\in [H]$, $f,f'\in [F]$: Let $\epsilon\sim\mathop{\mathrm{Unif}}\nolimits(\{1,\ldots,L\})$, $\theta\sim P_\theta$ and conditionally

Figures (20)

  • Figure 1: Asymptotic description of the attention trained by SGD. We compare numerical simulations at finite $D=10^4$ (dots) and the theoretical description stated in Proposition \ref{res:sgd} (continuous lines). We consider sequence length $L=10$, $H=2$ heads with $\sigma$ softmax attention, $F=2$ features, $\theta$ drawn from the flipping spike distribution, with signal strengths $\nu_1=\nu_2=2$. $f=1$ is the "constant" direction while $f=2$ is the "flipping sign" direction. Initialization $\eta=1$.
  • Figure 2: Evolution of the $H$ heads for the flipping spike distribution at $F=2$, in the span of the "constant" direction $k^*_1$ and the "flipping sign" direction $k^*_2$. $\sigma$ softmax, $L=4$, $\nu_1=\nu_2=2$ and $\eta=1$.
  • Figure 3: Evolution of the heads for the non-isotropic Gaussian distribution at $F=2$. $\sigma$ softmax, $L=4$, $\eta=1$. The two central panels at $H=8$ correspond to two different runs with different realizations of the initial condition.
  • Figure 4: Evolution of the $H=8$ heads for the non-isotropic Gaussian distribution at $F=3$, according to Prop. \ref{res:sgd}. Left: $\sigma$ softmax; right: $\sigma$ B-softmax. $L=5$, $\nu_1=20$, $\nu_2=1$ and $\eta=1$.
  • Figure 5: Predicted error $\mathcal{E}^\infty_\sigma$ of the different activation functions after training. $L=5$, $\eta=1$. Left: $F=4$, $\nu=10$, varying $H$; right: $F=2$, $H=4$, varying signal strength. We performed 5 runs with different initial conditions; the standard deviations are not visible.
  • ...and 15 more figures

Theorems & Definitions (24)

  • Proposition 3.1: Reparameterized loss
  • Proposition 3.2: Effective dynamics
  • Proposition 3.3: Unspecialized phase
  • Proposition 3.4: Specialization phase
  • Lemma 3.1: Hessian before specialization
  • Lemma 3.2: Hessian during specialization
  • Proposition 4.1: Bayes estimator
  • Proposition 4.2: Optimality of the Bayes-softmax attention
  • Proposition 4.3: Expressivity of softmax and softmax-1
  • Proof: Proof of Proposition \ref{res:paramLoss}
  • ...and 14 more