Fundamental limits of learning in sequence multi-index models and deep attention networks: High-dimensional asymptotics and sharp thresholds

Emanuele Troiani, Hugo Cui, Yatin Dandi, Florent Krzakala, Lenka Zdeborová

TL;DR

The paper establishes a theoretical framework for understanding learning in deep attention networks by mapping them to sequence multi-index (SMI) models, enabling high-dimensional asymptotics. It derives Bayes-optimal and AMP (via state evolution) performance in the large-D regime with finite sequence length, uncovering sharp weak-recovery thresholds and showing that layer weights are learned sequentially as data or iterations increase. Through analyses of two-layer and deeper architectures, it demonstrates layer-wise learning dynamics and depth effects, including a grand staircase mechanism for progressive subspace recovery. The work provides a principled basis for predicting learnability and learning dynamics in transformer-like models and suggests directions for extending the theory to more realistic architectures and input statistics.

Abstract

In this manuscript, we study the learning of deep attention neural networks, defined as the composition of multiple self-attention layers, with tied and low-rank weights. We first establish a mapping of such models to sequence multi-index models, a generalization of the widely studied multi-index model to sequential covariates, for which we establish a number of general results. In the context of Bayesian-optimal learning, in the limit of large dimension $D$ and commensurably large number of samples $N$, we derive a sharp asymptotic characterization of the optimal performance as well as the performance of the best-known polynomial-time algorithm for this setting, namely approximate message passing (AMP), and characterize sharp thresholds on the minimal sample complexity required for better-than-random prediction performance. Our analysis uncovers, in particular, how the different layers are learned sequentially. Finally, we discuss how this sequential learning can also be observed in a realistic setup.
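The mapping from deep attention networks to sequence multi-index models can be checked numerically. The sketch below rests on assumptions that this summary does not spell out. Each layer uses tied query/key weights $W_l \in \mathbb{R}^{D \times P_l}$, a row-wise softmax, and a skip connection of strength $c$, so that $X_l = (cI + A_l)X_{l-1}$. The readout, taken here as the last attention matrix, is also a hypothetical choice. Under these assumptions, the output depends on the input $X \in \mathbb{R}^{M \times D}$ only through the $M \times P$ projections $X[W_1,\dots,W_L]/\sqrt{D}$, which is the defining property of a sequence multi-index model.

```python
import numpy as np


def softmax_rows(s):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    s = s - s.max(axis=1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=1, keepdims=True)


def deep_attention(X, Ws, c=1.0):
    """Hypothetical deep attention network acting on the full sequence.

    X  : (M, D) input sequence of M tokens in dimension D.
    Ws : list of (D, P_l) tied query/key weights, one per layer (assumed form).
    Returns the last attention matrix (assumed readout).
    """
    D = X.shape[1]
    Xl = X
    A = None
    for W in Ws:
        H = Xl @ W / np.sqrt(D)      # (M, P_l) low-rank projection of the tokens
        A = softmax_rows(H @ H.T)    # (M, M) attention with tied weights
        Xl = c * Xl + A @ Xl         # skip connection of strength c
    return A


def deep_attention_from_projections(H_all, Ps, c=1.0):
    """Same output, computed only from the projections X @ [W_1 ... W_L] / sqrt(D).

    Uses X_l = T_l X with T_l = (c I + A_l) T_{l-1} an (M, M) matrix,
    so that X_{l-1} W_l / sqrt(D) = T_{l-1} (X W_l / sqrt(D)).
    """
    M = H_all.shape[0]
    T = np.eye(M)
    A = None
    start = 0
    for P in Ps:
        H = T @ H_all[:, start:start + P]  # layer-l projection, without touching X
        start += P
        A = softmax_rows(H @ H.T)
        T = c * T + A @ T
    return A


rng = np.random.default_rng(0)
M, D, Ps = 2, 500, [1, 1]                 # sequence length, dimension, ranks P_l
X = rng.standard_normal((M, D))
Ws = [rng.standard_normal((D, P)) for P in Ps]
H_all = X @ np.hstack(Ws) / np.sqrt(D)    # (M, P) with P = sum(Ps)

A_full = deep_attention(X, Ws)
A_proj = deep_attention_from_projections(H_all, Ps)
print(np.max(np.abs(A_full - A_proj)))    # agreement up to floating-point rounding
```

The two functions differ only in what they are allowed to see: the second never touches the $D$-dimensional tokens, which is why high-dimensional asymptotics can be carried out on the finite $M \times P$ projections.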

Paper Structure

This paper contains 44 sections, 4 theorems, 166 equations, 5 figures, 1 algorithm.

Key Result

Theorem 2.1

Consider the SMI model (eq:SMI) with non-linearity $g$. Let $\xi \in \mathbb{R}^{P \times M}$ denote a matrix with entries $\xi_{ij} \stackrel{\text{i.i.d.}}{\sim} \mathcal{N}(0,1)$. Let $H_{Y}(Q)$ denote the conditional entropy of the associated output channel $Y$ defined by Eq. (eq:effective) … admits a unique global extremizer $\hat{Q}^\star, Q^\star$. Then the asymptotic Bayes-optimal predi…

Figures (5)

  • Figure 1: (Left) Diagonal elements (red and blue) of the overlap $Q$ (Eq. eq:overlap) and prediction error (black) achieved by GAMP, as a function of the sample complexity $\alpha=N/D$, for a two-layer attention model with $P_1=P_2=1$, $c=1$, $M=2$. Off-diagonal elements of $Q$ are zero and are therefore not plotted. Crosses: numerical implementations of GAMP in dimension $D=1000$, averaged over $16$ runs. Continuous lines: theoretical prediction from Eqs. (eq:SE) and (eq:generalisation_error). Dashed line: prediction error of GAMP when the first-layer weights are fixed to zero. (Right) Weak-recovery thresholds, delineating the different stages of learning, as a function of the skip-connection strength $c$. The red (resp. blue) line indicates the sample complexity above which the second (resp. first) layer can be learned.
  • Figure 2: (Top) Evolution of the performance of GAMP with the number of iterations, as measured by the cosine similarity $\mathcal{C}_l =|\boldsymbol{w}_l^\star \hat{\boldsymbol{w}}_l^\top| / \|\hat{\boldsymbol{w}}_l\|$ between the GAMP estimate $\hat{\boldsymbol{w}}_l$ of the $l$-th layer weights and the target weights $\boldsymbol{w}_l^\star$, for $L=2$, $P_1=P_2=1$, $c=1$, $M=2$. We display a single run of the algorithm in dimension $D=1000$ at sample complexity $\alpha = N/D = 1.2$. (Bottom) Evolution of the cosine similarity, for the same target function, when training the same model with SGD. We display $8$ runs of the algorithm in dimension $D=500$ at sample complexity $\alpha=15$, with the average shown in bold. The numerical experiments were performed at $\lambda= 1.4\times 10^{-4}$, $\eta=15$, and batch size $200$, with each batch used for $3$ consecutive iterations.
  • Figure 3: (Left) Diagonal elements (red, blue and green) of the overlap $Q$ (Eq. eq:overlap) and prediction error (black) achieved by GAMP, as a function of the sample complexity $\alpha=N/D$, for a three-layer attention model with $P_1=P_2=P_3=1$, $c=1$, $M=2$. Off-diagonal elements of $Q$ are zero and are therefore not plotted. Dashed line: prediction error when the first- and second-layer weights are fixed to zero. (Right) Similarity $\mathcal{S}_l = \operatorname{tr}\big((\boldsymbol{w}_l^{t\top}\boldsymbol{w}_l^{t})(\boldsymbol{w}_l^{*\top}\boldsymbol{w}_l^{*})\big)/\big(\lVert \boldsymbol{w}_l^t\rVert^2\lVert \boldsymbol{w}_l^*\rVert^2\big)$ between the weights $\boldsymbol{w}_l^t$ at training step $t$ and the last-iterate weights $\boldsymbol{w}^*_l$, as a function of training time. A transformer with $L=2$ attention layers and a fully connected readout is trained on the TREC classification task (hovy-etal-2001-toward; li-roth-2002-learning). Different colors indicate different attention layers; multiple curves of the same color are distinct runs.
  • Figure 4: Factor graph representation of a sequence multi-index model with $4$ variable nodes (circles) and $3 + 4$ factor nodes (squares). Picture inspired by Aubin2018.
  • Figure 5: Left: Overlap $Q_{11}$ of the first layer in a two-layer rank-one attention model with skip connection $c=1$ at sample complexity $\alpha=1$, as a function of the second-layer overlap $Q_{22}$, which is kept fixed during the iterations. Unless $Q_{22}$ is close to one, $Q_{11}$ is not learned.
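The two alignment measures used in the captions of Figures 2 and 3 can be computed directly. The sketch below assumes the weights are stored as (P, D) arrays with one row per direction, matching the row-vector convention suggested by $\boldsymbol{w}_l^\star \hat{\boldsymbol{w}}_l^\top$. Normalizing by both norms in the cosine is our assumption; it coincides with the caption's $\mathcal{C}_l$ when $\|\boldsymbol{w}_l^\star\| = 1$. For rank-one weights, $\mathcal{S}_l$ reduces to the squared cosine.

```python
import numpy as np


def cosine_similarity(w_star, w_hat):
    """|<w*, w_hat>| / (||w*|| ||w_hat||) for rank-one weights.

    Matches the caption's C_l when the target is normalized to ||w*|| = 1 (assumption).
    """
    w_star, w_hat = np.ravel(w_star), np.ravel(w_hat)
    return abs(w_star @ w_hat) / (np.linalg.norm(w_star) * np.linalg.norm(w_hat))


def subspace_similarity(w_t, w_ref):
    """S_l = tr((w_t^T w_t)(w_ref^T w_ref)) / (||w_t||_F^2 ||w_ref||_F^2).

    w_t, w_ref: (P, D) arrays, one row per direction. By cyclicity of the trace,
    the numerator equals ||w_t w_ref^T||_F^2, which avoids forming (D, D) matrices.
    """
    num = np.linalg.norm(w_t @ w_ref.T, "fro") ** 2
    den = np.linalg.norm(w_t, "fro") ** 2 * np.linalg.norm(w_ref, "fro") ** 2
    return num / den


rng = np.random.default_rng(1)
D = 1000
w_star = rng.standard_normal((1, D))                  # rank-one target weights
w_hat = w_star + 2.0 * rng.standard_normal((1, D))    # noisy estimate of the target
print(cosine_similarity(w_star, w_hat))
print(subspace_similarity(w_hat, w_star))             # squared cosine when P = 1
```

Because $\mathcal{S}_l$ depends on the weights only through $\boldsymbol{w}_l^{\top}\boldsymbol{w}_l$, it is unchanged by orthogonal rotations of the rows; this is a reasonable choice when, as for tied query/key weights, the model itself depends on $\boldsymbol{w}_l$ only through that product.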

Theorems & Definitions (6)

  • Theorem 2.1
  • Lemma 2.2: State evolution (Gerbelot)
  • Theorem 2.3
  • Definition 1
  • Definition 2
  • Corollary E.1