Bayes optimal learning of attention-indexed models

Fabrizio Boncoraglio, Emanuele Troiani, Vittorio Erba, Lenka Zdeborová

TL;DR

The paper introduces the attention-indexed model (AIM), a theoretical framework for analyzing learning in deep attention layers. It proposes a matching approximate message passing algorithm and shows that gradient descent can reach optimal performance.

Abstract

We introduce the attention-indexed model (AIM), a theoretical framework for analyzing learning in deep attention layers. Inspired by multi-index models, AIM captures how token-level outputs emerge from layered bilinear interactions over high-dimensional embeddings. Unlike prior tractable attention models, AIM allows full-width key and query matrices, aligning more closely with practical transformers. Using tools from statistical mechanics and random matrix theory, we derive closed-form predictions for the Bayes-optimal generalization error and identify sharp phase transitions as a function of sample complexity, model width, and sequence length. We propose a matching approximate message passing algorithm and show that gradient descent can reach optimal performance. AIM offers a solvable playground for understanding learning in self-attention layers, which are key components of modern architectures.

Paper Structure

This paper contains 31 sections, 1 theorem, 247 equations, 5 figures, 1 algorithm.

Key Result

Corollary 4.2

Consider the model in Eq. \ref{eq:single_layer} with hardmax activation and $T=2$. In the high-dimensional limit of Eq. \ref{eq:limit}, the equation for $q$ of Result \ref{thm:SE} simplifies to $q = [\max(1-t, 0)]^2$ under the rescaling $\alpha = \bar{\alpha} \rho$ and $\hat{q} = t \rho$. In particular, the BO error is the sam
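
The simplified expression quoted in the corollary can be evaluated directly. The snippet below is a minimal sketch that only computes the right-hand side $[\max(1-t, 0)]^2$ as a function of $t$; the function name `q_small_width` is hypothetical, and the snippet does not solve the full state equations of the paper.

```python
def q_small_width(t: float) -> float:
    """Evaluate [max(1 - t, 0)]^2, the right-hand side quoted in Corollary 4.2.

    Illustrative helper only (the name is not from the paper); t is the
    rescaled variable defined through q_hat = t * rho.
    """
    return max(1.0 - t, 0.0) ** 2


if __name__ == "__main__":
    # The expression decreases from 1 at t = 0 and is identically 0 for t >= 1.
    for t in (0.0, 0.5, 1.0, 2.0):
        print(f"t = {t:.1f} -> q = {q_small_width(t):.4f}")
```

The clipping at zero means the expression stays at exactly 0 for every $t \geq 1$.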

Figures (5)

  • Figure 1: (Left) The Bayes-optimal error for the single-layer attention-indexed model with $T=2$ tokens and hardmax activation, for several values of the width ratio $\rho$ (Result \ref{res:hardmax}). The log-log scale highlights a large-$\alpha$ power-law decay of the BO estimation error, strikingly different from the softmax behaviour (see Figure \ref{fig2}). We also plot the corresponding errors achieved by the AMP algorithm (dots) at $d=100$, averaged over 16 realizations of the data and teacher weights. Error bars are computed with respect to the mean. We find good agreement even for such a moderate size. (Right) Focus on the small-width case of the Bayes-optimal error (Corollary \ref{cor:rank-limits-hard}) for the same model. We rescale the sample complexity to $\bar{\alpha}=\alpha / \rho$ and highlight the theoretical prediction of the weak recovery threshold (gray vertical line).
  • Figure 2: (Left) Illustration of the Bayes-optimal estimation error for the softmax tied-attention model of Eq. \ref{eq:single_layer} (Result \ref{thm:SE_one_layer}), for any $0 < \beta < +\infty$ and $T=2$ tokens, and for several values of the attention width ratio $\rho = r/d$. The model reaches zero BO error at a finite $\alpha$ that depends on $\rho$ (Eq. \ref{eq:softmax_recovery}). (Right) The black dashed lines show the theoretical prediction of the BO estimation error as a function of the sample complexity rescaled by the number of tokens, $\alpha/(T^2 + T - 2)$, for $\rho = 0.5$. We show the performance of the corresponding AMP algorithm for $T=2,3$ tokens, which correctly achieves the BO error. We also compare the BO performance with that of Adam GD and its averaged version AGD at $d=100$. We average each numerical experiment (GD, AGD, AMP) over $16$ realizations of the data and teacher weights. Error bars are the standard deviation on the mean.
  • Figure 3: Illustration of the Bayes-optimal error for the linear output channel baseline in Eq. \ref{eq:linear_model}, for $T=2$ tokens and several values of the width ratio $\rho=r/d$. The model reaches zero BO error at finite $\alpha$. The recovery threshold matches the one found by the simple counting argument in Eq. \ref{eq:counting}, plotted as short vertical lines.
  • Figure 4: Comparison between the fixed-point solutions of the state equations for a softmax output channel in Eq. \ref{eq:q_L1} and Eq. \ref{eq:state_eq_softmax_general} for $T=4,5$ tokens. We compare the theoretical solutions with the corresponding AMP algorithm, run over $16$ different realizations with $d=120$. The error bars on the AMP dots are computed with respect to the mean value.
  • Figure 5: Low-width limit of the self-attention model in Eq. \ref{eq:single_layer} for $L=1$ layer and $T=2$ tokens. We rescale the sample ratio as $\bar{\alpha}=n/(dr)$ and plot several values of the width ratio $\rho=r/d$. We correctly predict the weak recovery threshold in Eq. \ref{eq:weak_rec_threshold}.
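
The captions above use two rescalings of the sample complexity: $\bar{\alpha}=\alpha/\rho$ (Figures 1 and 5) and $\alpha/(T^2+T-2)$ (Figure 2). The following is a minimal sketch of these conversions. The function names are hypothetical. The relation $\alpha = n/d^2$ is not stated in the captions; it is an assumption, inferred here by combining $\bar{\alpha}=\alpha/\rho$, $\bar{\alpha}=n/(dr)$ and $\rho=r/d$.

```python
def rescale_by_width(alpha: float, rho: float) -> float:
    """Return alpha_bar = alpha / rho, the small-width rescaling (Figures 1 and 5)."""
    if rho <= 0:
        raise ValueError("rho must be positive")
    return alpha / rho


def rescale_by_tokens(alpha: float, T: int) -> float:
    """Return alpha / (T^2 + T - 2), the token rescaling quoted in Figure 2."""
    if T < 2:
        raise ValueError("T must be at least 2 so that T^2 + T - 2 > 0")
    return alpha / (T * T + T - 2)


def alpha_bar_from_sizes(n: int, d: int, r: int) -> float:
    """Return alpha_bar = n / (d * r), as written in the Figure 5 caption."""
    return n / (d * r)


if __name__ == "__main__":
    n, d, r = 5000, 100, 50
    rho = r / d            # width ratio rho = r / d
    alpha = n / d**2       # ASSUMPTION: alpha = n / d^2, inferred from the captions
    # Both routes to alpha_bar agree under that assumption.
    print(rescale_by_width(alpha, rho), alpha_bar_from_sizes(n, d, r))
    # The token factor T^2 + T - 2 equals 4 for T = 2 and 10 for T = 3.
    print(rescale_by_tokens(4.0, 2), rescale_by_tokens(10.0, 3))
```

With the token rescaling, $\alpha$ is divided by 4 for $T=2$ and by 10 for $T=3$, which is the horizontal axis used in the right panel of Figure 2.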

Theorems & Definitions (1)

  • Corollary 4.2: Small width limit for hardmax activation