Asymptotics of SGD in Sequence-Single Index Models and Single-Layer Attention Networks

Luca Arnaboldi, Bruno Loureiro, Ludovic Stephan, Florent Krzakala, Lenka Zdeborova

TL;DR

The paper analyzes SGD dynamics for Sequence Single-Index (SSI) models that generalize single-index learning to sequences with one-layer attention. It introduces the Sequence Information Exponent (SIE) via Hermite expansions and shows the population loss $R(w)$ depends only on the sufficient statistics $(e, m)$, yielding sharp SGD-sample-size scalings: $\mathcal{O}_L(d)$ for $\text{SIE}=1$, $\mathcal{O}_L(d\log^2 d)$ for $\text{SIE}=2$, and $\mathcal{O}_L(d^{\text{SIE}-1})$ for $\text{SIE}\ge 3$, with positional encoding able to reduce the SIE and potentially accelerate learning. The work also contrasts tied (linear attention) and untied networks to quantify a sequence-length–driven speedup, deriving a gain bound that scales with $L$ under favorable structure, and demonstrates a phase diagram where SGD can converge to semantic or positional minima depending on encoding and target structure. These results provide a rigorous, interpretable framework for understanding how sequential structure and positional encoding influence learning with attention-like models, guiding design choices for sequence tasks. Overall, the paper bridges theory for single- and multi-index models with modern sequence-attention architectures, offering precise predictions for sample complexity, convergence rates, and optimization landscapes in high dimensions.
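The three scaling regimes quoted above can be written as a small helper for quick comparisons. This is only a sketch: the function name `sgd_sample_scaling` is hypothetical, and the constants and the $L$-dependent prefactor hidden in $\mathcal{O}_L(\cdot)$ are deliberately dropped, so the returned values are meaningful only relative to one another.

```python
import math


def sgd_sample_scaling(d: int, sie: int) -> float:
    """Leading-order growth in d of the SGD sample size for a given SIE.

    Hypothetical helper: constants and the L-dependent prefactor of
    O_L(.) are ignored, so only the growth rate in d is represented.
    """
    if d < 2:
        raise ValueError("d must be at least 2")
    if sie < 1:
        raise ValueError("the SIE must be a positive integer")
    if sie == 1:
        return float(d)                # O_L(d)
    if sie == 2:
        return d * math.log(d) ** 2    # O_L(d log^2 d)
    return float(d) ** (sie - 1)       # O_L(d^(SIE-1)) for SIE >= 3


if __name__ == "__main__":
    for sie in (1, 2, 3, 4):
        print(sie, sgd_sample_scaling(1000, sie))
```

Comparing `sgd_sample_scaling(d, 4)` with `sgd_sample_scaling(d, 2)` gives a rough sense of what is at stake when positional encoding lowers the SIE.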

Abstract

We study the dynamics of stochastic gradient descent (SGD) for a class of sequence models termed Sequence Single-Index (SSI) models, where the target depends on a single direction in input space applied to a sequence of tokens. This setting generalizes classical single-index models to the sequential domain, encompassing simplified one-layer attention architectures. We derive a closed-form expression for the population loss in terms of a pair of sufficient statistics capturing semantic and positional alignment, and characterize the induced high-dimensional SGD dynamics for these coordinates. Our analysis reveals two distinct training phases: escape from uninformative initialization and alignment with the target subspace, and demonstrates how the sequence length and positional encoding influence convergence speed and learning trajectories. These results provide a rigorous and interpretable foundation for understanding how sequential structure in data can be beneficial for learning with attention-based models.
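The two training phases can be observed in a much simpler setting than the one studied in the paper. The sketch below runs one-pass spherical SGD on a classical (non-sequence) single-index model with a `tanh` link, which is an assumption made here for illustration: it is not the paper's SSI model, it has no attention layer and no positional encoding, and the function name, step size, and dimensions are all hypothetical choices. It tracks the overlap between the student and teacher directions at every step.

```python
import numpy as np


def online_sgd_overlap(d=100, steps=5000, lr=None, seed=0):
    """One-pass spherical SGD on a classical single-index model.

    Illustrative assumption: teacher and student share the link tanh,
    labels are noiseless, and inputs are standard Gaussian vectors.
    Returns the overlap <w_t, w_star> for t = 0..steps.
    """
    rng = np.random.default_rng(seed)
    lr = 0.5 / d if lr is None else lr

    # Teacher direction and uninformed student initialisation on the sphere.
    w_star = rng.standard_normal(d)
    w_star /= np.linalg.norm(w_star)
    w = rng.standard_normal(d)
    w /= np.linalg.norm(w)

    overlaps = np.empty(steps + 1)
    overlaps[0] = w @ w_star
    for t in range(steps):
        x = rng.standard_normal(d)       # fresh sample at every step
        y = np.tanh(x @ w_star)          # noiseless teacher label
        pred = np.tanh(x @ w)
        # Gradient of 0.5 * (pred - y)^2 with respect to w.
        grad = (pred - y) * (1.0 - pred ** 2) * x
        grad -= (grad @ w) * w           # project onto the tangent space
        w = w - lr * grad
        w /= np.linalg.norm(w)           # retract back to the unit sphere
        overlaps[t + 1] = w @ w_star
    return overlaps


if __name__ == "__main__":
    m = online_sgd_overlap()
    print("initial overlap:", m[0], "final overlap:", m[-1])
```

Plotting the returned array against the step index should show a slow start from an overlap of order $1/\sqrt{d}$ followed by a rise towards alignment. The `tanh` link has a nonzero first Hermite coefficient, so the escape phase is short here; targets with a higher information exponent would need a different link and many more steps.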

Paper Structure

This paper contains 43 sections, 11 theorems, 119 equations, 13 figures.

Key Result

Theorem 1

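Let $\boldsymbol{f}^\mathrm{SSI}_{\boldsymbol{w}_{\star}}(X)$ be a sequence single-index model, and let $\text{SIE}$ be its sequence information exponent. If the model $f_{\boldsymbol{w}}$ has a rich enough Hermite expansion, then the sample complexity of the SGD algorithm is $\mathcal{O}_L(d)$ for $\text{SIE}=1$, $\mathcal{O}_L(d\log^2 d)$ for $\text{SIE}=2$, and $\mathcal{O}_L(d^{\text{SIE}-1})$ for $\text{SIE}\ge 3$.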

Figures (13)

  • Figure 1: The landscape of the population risk, together with the Hessian at initialization, for different values of the SIE and positional encoding. (left) $g(\boldsymbol{z}_\star) = \mathrm{He}_2(\boldsymbol{z}_{\star,1}) + \mathrm{He}_2(\boldsymbol{z}_{\star,2})$: SIE=2, no positional encoding: null gradient, but non-null Hessian; (center-left) SIE=4, no positional encoding: the first non-null term at initialization is at the 4th order; (center-right) $g(\boldsymbol{z}_\star) = \mathrm{He}_4(\boldsymbol{z}_{\star,1}) + \mathrm{He}_4(\boldsymbol{z}_{\star,2})$: SIE=2, with positional encoding: the dynamics are again dominated by the Hessian, but with positive curvature in the direction of $e$; (right) SIE=4, with positional encoding: the Hessian is positive semidefinite, and the dynamics are again of 4th order in the direction of $e$. In all the examples $L=2, P_1=-P_2, R = \mathop{\mathrm{Tr}}\nolimits$.
  • Figure 2: Population loss landscape for $P=0$ (left) and $P\neq 0$ (right). Example of a case where ${\rm SIE}=4$, while ${\rm SIE}_\mathrm{positional}=2$. Target: $g(\boldsymbol{z}_\star) = 4/3 + \mathrm{He}_4(z_{\star,1}) + 2\mathrm{He}_4(z_{\star,2})$, $P_1=-P_2$, $R = \mathop{\mathrm{Tr}}\nolimits$.
  • Figure 3: Left: overlap $m$ for tied (green) and untied (orange) networks as a function of the number of gradient steps; different symbols represent different values of $L$. Right: measured gain as a function of the sequence length $L$, with the best-fit line showing its scaling as $L^2$. Target: $g(\boldsymbol{z}_\star)= \sum_{i=1}^L \mathrm{He}_2(z_{\star,i})$, $d=1000$, $\sigma = \mathop{\mathrm{ReLU}}\nolimits$.
  • Figure 4: Different behaviors of SGD depending on the parameters $\omega$ and $a$.
  • Figure 5: (left) Surface of the population loss for $\omega = 0.67$ and $a=1$. The steepest direction at initialization (green vector) points towards the positional local minimum, while the global minimum is semantic. Some examples of SGD trajectories are shown in yellow: most of them fall into the semantic local minimum, while some others manage to fully recover the global minimum due to finite-size effects ($d=1000$). (right) Empirical probability of convergence to the semantic minimum as a function of $\omega$ for $a=1$. The probability is computed over 64 SGD runs with different initializations and data samples. The theoretical prediction of the transition from semantic to positional minima is at $\omega_\mathrm{trans}\approx0.64$.
  • ...and 8 more figures

Theorems & Definitions (16)

  • Definition 1: Weak recovery
  • Definition 2: Sequence Information Exponent (SIE)
  • Theorem 1: Informal
  • Lemma 1
  • Theorem 2
  • Lemma 2
  • Theorem 3
  • Lemma 3
  • Lemma 4
  • Theorem 4
  • ...and 6 more
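
Definition 2 is only listed by name here, so the sketch below illustrates the classical scalar information exponent that the SIE generalizes: the index of the first nonzero Hermite coefficient of degree at least one. It is not the paper's sequence definition. The function names, the quadrature size, and the tolerance are assumptions made for this example; the Hermite routines come from `numpy.polynomial.hermite_e`, which uses the probabilists' convention $\mathrm{He}_k$.

```python
import math

import numpy as np
from numpy.polynomial import hermite_e as herme


def hermite_coefficients(g, max_degree=8, n_nodes=60):
    """Coefficients c_k = E[g(z) He_k(z)] / k! for z ~ N(0, 1).

    Expectations are computed with Gauss-Hermite quadrature for the
    weight exp(-z^2 / 2), renormalised to the standard Gaussian density.
    """
    nodes, weights = herme.hermegauss(n_nodes)
    weights = weights / math.sqrt(2.0 * math.pi)
    g_values = g(nodes)
    coeffs = np.empty(max_degree + 1)
    for k in range(max_degree + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0                         # selects He_k
        he_k = herme.hermeval(nodes, basis)
        coeffs[k] = np.sum(weights * g_values * he_k) / math.factorial(k)
    return coeffs


def information_exponent(g, tol=1e-8, max_degree=8):
    """Smallest k >= 1 with |c_k| > tol, or None if none up to max_degree."""
    coeffs = hermite_coefficients(g, max_degree=max_degree)
    for k in range(1, max_degree + 1):
        if abs(coeffs[k]) > tol:
            return k
    return None


if __name__ == "__main__":
    print(information_exponent(lambda z: z ** 2 - 1))               # He_2
    print(information_exponent(lambda z: z ** 4 - 6 * z ** 2 + 3))  # He_4
    print(information_exponent(np.tanh))
```

The polynomial targets used in the figure captions, $\mathrm{He}_2$ and $\mathrm{He}_4$, are handled exactly by the quadrature; for non-polynomial links such as `tanh` the coefficients are approximate, so the tolerance `tol` decides what counts as zero.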