
Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape

Juno Kim, Taiji Suzuki

TL;DR

This paper studies the optimization of a Transformer consisting of a fully connected layer followed by a linear attention layer, and proves in the mean-field and two-timescale limit that the infinite-dimensional loss landscape for the distribution of parameters, while highly nonconvex, becomes quite benign.
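To make the architecture concrete, here is a minimal sketch of one forward pass. It rests on assumptions that are not taken from the paper. Each neuron is the hypothetical map $h(x; w) = a \tanh(v^\top x)$ with $w = (a, v)$ and $a \in \mathbb{R}^k$. The feature map $h_\mu$ is approximated by averaging over $N$ particles, which is the usual mean-field scaling. The linear attention layer is assumed to return $h_\mu(x_q)^\top \Gamma \frac{1}{n}\sum_j h_\mu(x_j) y_j$ for a $k \times k$ matrix $\Gamma$. The names `neuron`, `mean_field_features`, and `linear_attention_predict` are illustrative only.

```python
import math


def neuron(x, a, v):
    """Hypothetical neuron h(x; w) = a * tanh(<v, x>) with w = (a, v); returns a vector in R^k."""
    s = math.tanh(sum(vi * xi for vi, xi in zip(v, x)))
    return [ai * s for ai in a]


def mean_field_features(x, particles):
    """Mean-field feature map: h_mu(x) is approximated by (1/N) * sum_i h(x; w_i)."""
    n_particles = len(particles)
    k = len(particles[0][0])
    out = [0.0] * k
    for a, v in particles:
        for j, hj in enumerate(neuron(x, a, v)):
            out[j] += hj / n_particles
    return out


def linear_attention_predict(x_query, context, particles, gamma):
    """Assumed linear-attention readout: h(x_q)^T Gamma (1/n) sum_j h(x_j) y_j."""
    k = len(gamma)
    # Label-weighted average of context features, (1/n) * sum_j h(x_j) y_j.
    m = [0.0] * k
    for x, y in context:
        for j, hj in enumerate(mean_field_features(x, particles)):
            m[j] += hj * y / len(context)
    h_q = mean_field_features(x_query, particles)
    # Bilinear form h_q^T Gamma m.
    return sum(h_q[i] * gamma[i][j] * m[j] for i in range(k) for j in range(k))
```

Under this assumed form, the prediction is linear in the context labels $y_j$. All of the nonlinearity therefore enters through the shared feature map, which matches the role the abstract assigns to the MLP.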

Abstract

Large language models based on the Transformer architecture have demonstrated impressive capabilities to learn in context. However, existing theoretical studies on how this phenomenon arises are limited to the dynamics of a single layer of attention trained on linear regression tasks. In this paper, we study the optimization of a Transformer consisting of a fully connected layer followed by a linear attention layer. The MLP acts as a common nonlinear representation or feature map, greatly enhancing the power of in-context learning. We prove in the mean-field and two-timescale limit that the infinite-dimensional loss landscape for the distribution of parameters, while highly nonconvex, becomes quite benign. We also analyze the second-order stability of mean-field dynamics and show that Wasserstein gradient flow almost always avoids saddle points. Furthermore, we establish novel methods for obtaining concrete improvement rates both away from and near critical points. This represents the first saddle point analysis of mean-field dynamics in general and the techniques are of independent interest.

Paper Structure

This paper contains 64 sections, 38 theorems, 145 equations, 1 figure, and 1 algorithm.

Key Result

Proposition 2.1

Suppose $h_\mu$ includes a bias term, i.e., $\mathscr{X} \subseteq \mathscr{X}_0 \times \{1\}$. If $f = (f_j)_{j=1}^k \in C(\mathscr{X}_0, \mathbb{R}^k)$ is such that each $f_j$ satisfies …
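The bias condition can be unpacked under one assumption: that each neuron sees its input through an inner product $w^\top x$. This is standard, but it is not spelled out in this excerpt. The condition then says that the last input coordinate is the constant 1, so the matching weight coordinate acts as an intercept:

$$
x \in \mathscr{X} \subseteq \mathscr{X}_0 \times \{1\} \;\Longrightarrow\; x = (x_0, 1), \qquad w = (w_0, b) \;\Longrightarrow\; w^\top x = w_0^\top x_0 + b.
$$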

Figures (1)

  • Figure 1: (a) Training error of the attention, static, and modified Transformers. (b) Learning degenerate features with $\operatorname{rank} \boldsymbol{\Sigma}_{\mu^\circ,\mu^\circ} < k$. (c) Training a misspecified model containing two extra features. (d) Test error for the nonlinear norm task $\lVert h_{\mu^\circ}(\boldsymbol{x}) \rVert$.
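Panel (b) refers to a rank-deficient feature covariance. As a hedged illustration, assume that $\boldsymbol{\Sigma}_{\mu,\mu'}$ is the second-moment matrix $\mathbb{E}_{x}[h_\mu(x) h_{\mu'}(x)^\top]$. This definition is ours and is not quoted from the paper. The sketch estimates $\boldsymbol{\Sigma}_{\mu^\circ,\mu^\circ}$ from samples and computes its numerical rank. In this reading, "degenerate" means that some direction of the $k$-dimensional feature space carries no signal. `feature_covariance` and `numerical_rank` are hypothetical helpers.

```python
import math


def feature_covariance(features, xs):
    """Empirical Sigma = (1/S) sum_s h(x_s) h(x_s)^T for a feature map h: x -> R^k."""
    hs = [features(x) for x in xs]
    k = len(hs[0])
    return [[sum(h[i] * h[j] for h in hs) / len(hs) for j in range(k)] for i in range(k)]


def numerical_rank(mat, tol=1e-9):
    """Rank via Gaussian elimination with partial pivoting; pivots with |value| <= tol count as zero."""
    a = [row[:] for row in mat]
    rows, cols = len(a), len(a[0])
    rank = 0
    for c in range(cols):
        # Among rows not yet used as pivots, pick the one with the largest entry in column c.
        pivot = max(range(rank, rows), key=lambda r: abs(a[r][c]), default=None)
        if pivot is None or abs(a[pivot][c]) <= tol:
            continue
        a[rank], a[pivot] = a[pivot], a[rank]
        # Eliminate column c from all rows below the pivot row.
        for r in range(rank + 1, rows):
            factor = a[r][c] / a[rank][c]
            for cc in range(c, cols):
                a[r][cc] -= factor * a[rank][cc]
        rank += 1
    return rank
```

For example, the feature map $x \mapsto (\tanh x, 2\tanh x)$ has $k = 2$, but its covariance has rank 1. In contrast, $x \mapsto (\tanh x, x)$ gives rank 2 on the sample points $\{-1, -0.5, 0.5, 1\}$.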

Theorems & Definitions (61)

  • Proposition 2.1
  • Lemma 2.2
  • Proposition 2.3
  • Proposition 2.4
  • Proposition 2.5
  • Lemma 3.1
  • Lemma 3.2
  • Theorem 3.3: no spurious local minima
  • Proposition 3.4: accelerated convergence phase
  • Lemma 4.1
  • ...and 51 more
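Theorem 3.3 and the saddle-avoidance result concern an infinite-dimensional landscape, but the two properties have a simple finite-dimensional analogue. Consider the hypothetical toy objective $f(x, y) = (x^2 - 1)^2/4 + y^2/2$. Its critical points are a strict saddle at $(0, 0)$ with Hessian $\mathrm{diag}(-1, 1)$ and two minima at $(\pm 1, 0)$. Both minima are global with $f = 0$, so there are no spurious local minima. Started in the square $[-1, 1]^2$ with step size 0.1, gradient descent reaches the saddle only from points on the line $x = 0$, a set of measure zero. This sketch illustrates the phenomenon only; it is not the paper's mean-field analysis.

```python
import random


def f(x, y):
    """Toy objective with a strict saddle at (0, 0) and global minima at (+-1, 0)."""
    return (x * x - 1.0) ** 2 / 4.0 + y * y / 2.0


def grad(x, y):
    """Gradient of f: (x (x^2 - 1), y)."""
    return x * (x * x - 1.0), y


def gradient_descent(x, y, lr=0.1, steps=1000):
    """Plain gradient descent from (x, y); returns the final iterate."""
    for _ in range(steps):
        gx, gy = grad(x, y)
        x, y = x - lr * gx, y - lr * gy
    return x, y


# Random initializations almost surely avoid the saddle and end at a global minimum.
rng = random.Random(0)
endpoints = [gradient_descent(rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)) for _ in range(100)]
```

Sampling initial points at random mirrors the "almost always" qualifier in the abstract. Only the measure-zero line $x = 0$ leads to the saddle.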