Fundamental limits of learning in sequence multi-index models and deep attention networks: High-dimensional asymptotics and sharp thresholds
Emanuele Troiani, Hugo Cui, Yatin Dandi, Florent Krzakala, Lenka Zdeborová
TL;DR
The paper establishes a theoretical framework for understanding learning in deep attention networks by mapping them to sequence multi-index (SMI) models, which makes a high-dimensional asymptotic analysis tractable. It characterizes the Bayes-optimal performance and the performance of approximate message passing (AMP, tracked via state evolution) in the large-$D$ regime with finite sequence length, uncovering sharp weak-recovery thresholds and showing that layer weights are learned sequentially as the number of samples or iterations increases. Through analyses of two-layer and deeper architectures, it demonstrates layer-wise learning dynamics and depth effects, including a grand staircase mechanism for progressive subspace recovery. The work provides a principled basis for predicting learnability and learning dynamics in transformer-like models and suggests directions for extending the theory to more realistic architectures and input statistics.
Abstract
In this manuscript, we study the learning of deep attention neural networks, defined as the composition of multiple self-attention layers, with tied and low-rank weights. We first establish a mapping of such models to sequence multi-index models, a generalization of the widely studied multi-index model to sequential covariates, for which we establish a number of general results. In the context of Bayesian-optimal learning, in the limit of large dimension $D$ and a commensurably large number of samples $N$, we derive a sharp asymptotic characterization of the optimal performance as well as the performance of the best-known polynomial-time algorithm for this setting, namely approximate message-passing, and characterize sharp thresholds on the minimal sample complexity required for better-than-random prediction performance. Our analysis uncovers, in particular, how the different layers are learned sequentially. Finally, we discuss how this sequential learning can also be observed in a realistic setup.
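
The mapping to a sequence multi-index model can be illustrated numerically. The sketch below is a minimal illustration under stated assumptions, not the paper's exact parameterization: the skip constant `c`, the $1/\sqrt{D}$ weight scaling, the choice of the last attention matrix as the output, and the function names `deep_attention` and `deep_attention_from_projections` are all hypothetical choices made for this example. It shows that a composition of tied, low-rank attention layers depends on the input sequence only through its low-dimensional projections onto the layer weights.

```python
import numpy as np

def softmax_rows(z):
    """Numerically stable softmax applied to each row."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def deep_attention(x, weights, c=1.0):
    """Compose tied, low-rank self-attention layers (assumed form).

    x: (L, D) sequence of L tokens in dimension D.
    weights: list of (D, r) matrices, one per layer; queries and keys
             share the same matrix (tied weights), with r much smaller than D.
    Returns the (L, L) attention matrix of the last layer.
    """
    h = x
    for w in weights:
        p = h @ w                  # (L, r) low-dimensional projection
        a = softmax_rows(p @ p.T)  # (L, L) attention scores
        h = c * h + a @ h          # skip connection plus attention mixing
    return a

def deep_attention_from_projections(projs, c=1.0):
    """Same output, computed only from projs[l] = x @ weights[l].

    Each hidden state is h = m @ x for an (L, L) mixing matrix m, so
    h @ w = m @ (x @ w) and the D-dimensional input is never needed.
    """
    n_tokens = projs[0].shape[0]
    m = np.eye(n_tokens)
    for p0 in projs:
        p = m @ p0
        a = softmax_rows(p @ p.T)
        m = (c * np.eye(n_tokens) + a) @ m
    return a

# Small demonstration with hypothetical sizes.
rng = np.random.default_rng(0)
L_tokens, D, r, depth = 4, 200, 2, 3
x = rng.standard_normal((L_tokens, D))
ws = [rng.standard_normal((D, r)) / np.sqrt(D) for _ in range(depth)]

full = deep_attention(x, ws)
reduced = deep_attention_from_projections([x @ w for w in ws])
print(np.allclose(full, reduced))
```

In this sketch the output is a function of `depth * r` projections per token rather than of the full $D$-dimensional input, which is the structural property that a multi-index description relies on.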
