In-Context Convergence of Transformers

Yu Huang, Yuan Cheng, Yingbin Liang

TL;DR

This work tackles the theoretical understanding of in-context learning with softmax-attention transformers by analyzing a one-layer model trained via gradient descent on linear-function tasks. It introduces a structured data framework with balanced and imbalanced feature distributions and derives two convergence results. For balanced features, it gives a finite-time guarantee of near-zero error. For imbalanced features, it shows stage-wise convergence (one training phase for dominant features, four for under-represented ones), driven by the evolving bilinear attention weights $A_k^{(t)}$ and $B_{k,n}^{(t)}$. The study shows that softmax attention trained by gradient descent can concentrate on tokens sharing the query's feature and predict accurately on unseen prompts, shedding light on the in-context learning mechanism of nonlinear transformers. The results bear on when and how in-context knowledge can be learned and transferred in practical transformer architectures, and the analytic techniques may extend to broader models and tasks.

Abstract

Transformers have recently revolutionized many domains in modern machine learning, and one salient discovery is their remarkable in-context learning capability: models can solve an unseen task by utilizing task-specific prompts without further parameter fine-tuning. This has also inspired recent theoretical studies aiming to understand the in-context learning mechanism of transformers, which, however, focused only on linear transformers. In this work, we take the first step toward studying the learning dynamics of a one-layer transformer with softmax attention trained via gradient descent to in-context learn linear function classes. We consider a structured data model, where each token is randomly sampled from a set of feature vectors in either a balanced or an imbalanced fashion. For data with balanced features, we establish a finite-time convergence guarantee with near-zero prediction error by carrying our analysis through two phases of the training dynamics of the attention map. More notably, for data with imbalanced features, we show that the learning dynamics follow a stage-wise convergence process. The transformer first converges to a near-zero prediction error for query tokens of dominant features, and later converges to a near-zero prediction error for query tokens of under-represented features, via one and four training phases, respectively. Our proof features new techniques for analyzing the competing strengths of two types of attention weights, whose changes determine the different training phases.
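A minimal sketch of the structured data model described above may help fix ideas. It rests on assumptions that this summary does not spell out: the K feature vectors are taken to be the standard basis of R^K (so a token can be stored as its feature index), the task vector w has i.i.d. standard Gaussian entries, and each label is y_i = <w, x_i>. The "imbalanced" distribution is an assumed form in which one feature carries a constant share of the probability mass. The names `balanced_probs`, `imbalanced_probs`, and `sample_prompt` are hypothetical.

```python
import random
from typing import List, Tuple


def balanced_probs(K: int) -> List[float]:
    """Balanced features: every feature has probability 1/K, i.e. p_k = Theta(1/K)."""
    return [1.0 / K] * K


def imbalanced_probs(K: int, p_dominant: float) -> List[float]:
    """Imbalanced features (assumed form): index 0 plays the dominant feature v_1
    with constant probability p_dominant; the rest is split evenly."""
    rest = (1.0 - p_dominant) / (K - 1)
    return [p_dominant] + [rest] * (K - 1)


def sample_prompt(K: int, N: int, probs: List[float], rng: random.Random
                  ) -> Tuple[List[int], int, List[float], float]:
    """Sample N labeled prompt tokens plus one query token.

    Assumption: features are the standard basis e_0..e_{K-1} of R^K, so a token
    is stored as its feature index and <w, e_k> is simply w[k].
    """
    w = [rng.gauss(0.0, 1.0) for _ in range(K)]          # task-specific linear function
    feats = rng.choices(range(K), weights=probs, k=N)     # feature index of each token
    labels = [w[k] for k in feats]                        # y_i = <w, x_i>
    query = rng.choices(range(K), weights=probs, k=1)[0]  # feature index of the query
    return feats, query, labels, w[query]                 # last entry: target for the query


rng = random.Random(0)
feats, query, labels, target = sample_prompt(K=5, N=20, probs=imbalanced_probs(5, 0.6), rng=rng)
print(len(feats), 0 <= query < 5)  # prints: 20 True
```

Storing tokens as indices keeps the sketch dependency-free; with orthonormal features, every inner product the model needs reduces to a table lookup.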

Paper Structure

This paper contains 77 sections, 68 theorems, 260 equations, 1 figure, and 1 table.

Key Result

Theorem 3.1

Suppose $p_k=\Theta\left(\frac{1}{K}\right)$ for $k \in [K]$. For any $0<\epsilon<1$, suppose $N\geq \mathrm{poly}(K)$ and $\mathrm{polylog}(K)\gg \log(\frac{1}{\epsilon})$. We apply GD to train the loss function given in Eq. (eq:obj). Then with at most $T^*=O\left(\frac{\log(K)K^2}{\eta}+\cdots\right)$ …
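To give a feel for the balanced-feature setting of Theorem 3.1, here is a toy stochastic-gradient run. It is not the paper's analysis. It assumes a simplified model: standard-basis features, a single key-query matrix Q, and the prediction $\hat y = \sum_i y_i\,\mathrm{softmax}_i(Q[k_i][k_q])$, where $k_i$ and $k_q$ are the feature indices of token i and of the query. It takes one gradient step on $\frac{1}{2}(\hat y - y)^2$ per sampled prompt, whereas the paper analyzes GD on its loss. Step size, sizes, and the helper names `train_balanced` and `mean_query_attention` are illustrative assumptions. One identity is worth noting: under this model the gradient with respect to the diagonal weight $A_k = Q[k][k]$ equals $-\operatorname{\bf Attn}_k(\hat y - y)^2$, so each step can only increase $A_k$.

```python
import math
import random
from typing import List


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def train_balanced(K: int = 4, N: int = 40, steps: int = 4000,
                   lr: float = 0.2, seed: int = 0) -> List[List[float]]:
    """Toy per-prompt gradient descent on Q with balanced features (hypothetical setup)."""
    rng = random.Random(seed)
    Q = [[0.0] * K for _ in range(K)]  # Q[m][q]: bilinear weight between features v_m and v_q
    for _ in range(steps):
        w = [rng.gauss(0.0, 1.0) for _ in range(K)]   # fresh linear task per prompt
        feats = [rng.randrange(K) for _ in range(N)]  # balanced: p_k = 1/K
        q = rng.randrange(K)                          # query feature
        y = w[q]
        s = softmax([Q[k][q] for k in feats])
        y_hat = sum(si * w[k] for si, k in zip(s, feats))
        attn = [0.0] * K                              # attention score per feature
        for si, k in zip(s, feats):
            attn[k] += si
        err = y_hat - y
        # dL/dQ[m][q] = err * Attn_m * (w_m - y_hat); for m == q this is -Attn_q * err**2.
        for m in range(K):
            Q[m][q] -= lr * err * attn[m] * (w[m] - y_hat)
    return Q


def mean_query_attention(Q: List[List[float]], K: int, N: int = 40,
                         trials: int = 500, seed: int = 1) -> float:
    """Average attention on tokens sharing the query's feature, over fresh balanced prompts."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        feats = [rng.randrange(K) for _ in range(N)]
        q = rng.randrange(K)
        s = softmax([Q[k][q] for k in feats])
        total += sum(si for si, k in zip(s, feats) if k == q)
    return total / trials


Q = train_balanced()
print("diagonal A_k:", [round(Q[k][k], 2) for k in range(4)])
print("Attn on query feature, before:", round(mean_query_attention([[0.0] * 4 for _ in range(4)], 4), 3))
print("Attn on query feature, after:", round(mean_query_attention(Q, 4), 3))
```

Because only column $k_q$ of Q receives gradient from a prompt whose query has feature $v_{k_q}$, each query feature effectively trains its own column of weights.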

Figures (1)

  • Figure 1: Overview of the dynamics of attention scores and bilinear attention weights for under-represented features. Assume the query token is $v_k$ with $2\leq k\leq K$. The top row depicts the trend of the attention score $\operatorname{\bf Attn}^{(t)}_m$ for each feature $v_m$, where a darker color corresponds to a higher score. The bottom row shows the interplay among the bilinear attention weights $A^{(t)}_k$, $B^{(t)}_{k,1}$, and $B^{(t)}_{k,n}$ (where $n \neq 1,k$), and which of them has the leading effect in each training phase. (a) Phase I: $B^{(t)}_{k,1}$ decreases significantly, suppressing the attention on tokens with the dominant feature $v_1$; (b) Phase II: with $\operatorname{\bf Attn}_{1}^{(t)}$ suppressed, the rate at which $B^{(t)}_{k,1}$ decreases drops, and the growth of $A_{k}^{(t)}$ becomes the leading influence; (c) Phase III: $A^{(t)}_k$ grows rapidly and $\operatorname{\bf Attn}^{(t)}_k$ reaches $\Omega(1)$; (d) Phase IV: $\operatorname{\bf Attn}^{(t)}_k$ grows to nearly $1$ and the prediction error converges to a global minimum.
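The phase picture in Figure 1 can be made concrete with a small calculation, assuming a simplified model that this summary does not state: every prompt token carrying feature $v_m$ receives the same logit $\theta_m$ for a query with feature $v_k$, where $\theta_k = A_k$ and $\theta_m = B_{k,m}$ for $m \neq k$. The attention score then has the closed form $\operatorname{\bf Attn}_m = N_m e^{\theta_m} / \sum_n N_n e^{\theta_n}$, with $N_m$ the number of tokens carrying $v_m$. The token counts and weight values below are made-up snapshots for illustration, not values from the paper, and `attention_by_feature` is a hypothetical name.

```python
import math
from typing import List


def attention_by_feature(counts: List[int], A_k: float, B_k: List[float], k: int) -> List[float]:
    """Closed-form Attn_m for a query with feature index k (hypothetical simplified model).

    counts[m]: number of prompt tokens with feature m.
    A_k: bilinear weight for the query's own feature.
    B_k[m]: cross weight B_{k,m} for m != k (the entry B_k[k] is ignored).
    """
    logits = [A_k if m == k else B_k[m] for m in range(len(counts))]
    top = max(logits)  # subtract the max logit for numerical stability
    mass = [c * math.exp(t - top) for c, t in zip(counts, logits)]
    total = sum(mass)
    return [x / total for x in mass]


# Imbalanced prompt: index 0 plays the dominant feature v_1; index 2 is an
# under-represented query feature. Cross weights to other features stay at 0.
counts = [60, 10, 10, 10, 10]
k = 2
snapshots = {
    "init": (0.0, [0.0, 0.0, 0.0, 0.0, 0.0]),
    "Phase I": (0.0, [-3.0, 0.0, 0.0, 0.0, 0.0]),    # B_{k,1} pushed down
    "Phase III": (3.0, [-3.0, 0.0, 0.0, 0.0, 0.0]),  # A_k has grown
    "Phase IV": (8.0, [-3.0, 0.0, 0.0, 0.0, 0.0]),   # A_k large
}
for name, (A_k, B_k) in snapshots.items():
    attn = attention_by_feature(counts, A_k, B_k, k)
    print(f"{name:9s} Attn_1={attn[0]:.3f} Attn_k={attn[k]:.3f}")
```

The counts factor $N_m$ shows why the dominant feature initially absorbs most of the attention, and why pushing $B_{k,1}$ down first frees up attention mass that a growing $A_k$ can later capture.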

Theorems & Definitions (130)

  • Definition 2.1: Self-Attention (SA) Mechanism
  • Remark 1: Nearly no loss of optimality
  • Definition 3.1: Attention Score
  • Theorem 3.1: In-context Learning with Balanced Features
  • Theorem 3.2: In-context Learning with Imbalanced Features
  • Definition 4.1
  • Lemma 4.1
  • Lemma 4.2: Informal
  • Lemma 4.3: Informal
  • Lemma 4.4: Informal
  • ...and 120 more
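The first two definitions listed above (the self-attention mechanism and the attention score) can be illustrated with a small sketch. It assumes a simplified form that this summary does not give: the prediction for the query is a softmax-weighted average of the prompt labels, $\hat y = \sum_i y_i\,\mathrm{softmax}_i(x_i^\top Q\, x_{\text{query}})$, with a single key-query matrix Q. The attention score $\operatorname{\bf Attn}_m$ is the total softmax weight on tokens carrying feature $v_m$. With standard-basis features, $x_i^\top Q\, x_{\text{query}}$ reduces to the entry `Q[k_i][k_query]`. The function names are hypothetical.

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict(Q: List[List[float]], feats: List[int], labels: List[float], query: int) -> float:
    """Softmax-attention prediction for the query token (assumed simplified form)."""
    weights = softmax([Q[k][query] for k in feats])  # x_i^T Q x_query with basis features
    return sum(a * y for a, y in zip(weights, labels))


def attention_scores(Q: List[List[float]], feats: List[int], query: int, K: int) -> List[float]:
    """Attn_m: total attention the query places on tokens whose feature is v_m."""
    weights = softmax([Q[k][query] for k in feats])
    attn = [0.0] * K
    for k, a in zip(feats, weights):
        attn[k] += a
    return attn


K = 3
Q = [[0.0] * K for _ in range(K)]  # zero initialization gives uniform attention over tokens
feats = [0, 0, 1, 2]
print(attention_scores(Q, feats, query=1, K=K))  # prints: [0.5, 0.25, 0.25]
```

At zero initialization each feature's attention score simply equals its share of the prompt tokens, which is why imbalanced prompts start out attending mostly to the dominant feature.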