
In-Context Learning with Representations: Contextual Generalization of Trained Transformers

Tong Yang, Yu Huang, Yingbin Liang, Yuejie Chi

TL;DR

This work provides a first-principles analysis of how transformers can acquire contextual (template) information from prompts to generalize to unseen examples and tasks without parameter updates. By modeling ICL as learning a linear combination of fixed basis representations and training a one-layer, multi-head transformer with gradient descent, the authors prove linear convergence and show that inference implements ridge regression over the basis functions. The results show that contextual generalization emerges even with few, noisy prompt labels, and that the representation dimension relative to the prompt size critically shapes performance. The analysis relaxes several strong assumptions in prior theory and underscores the essential role of multi-head attention in enabling this contextual memorization and generalization. The accompanying experiments on synthetic data corroborate the theory and illustrate practical implications for prompt design and architecture choice.

Abstract

In-context learning (ICL) refers to a remarkable capability of pretrained large language models, which can learn a new task given a few examples during inference. However, theoretical understanding of ICL is largely under-explored, particularly whether transformers can be trained to generalize to unseen examples in a prompt, which requires the model to acquire contextual knowledge of the prompt. This paper investigates the training dynamics of transformers by gradient descent through the lens of non-linear regression tasks. The contextual generalization here can be attained via learning the template function for each task in-context, where all template functions lie in a linear space with $m$ basis functions. We analyze the training dynamics of one-layer multi-head transformers that predict unlabeled inputs in context given partially labeled prompts, where the labels contain Gaussian noise and the number of examples in each prompt is not sufficient to determine the template. Under mild assumptions, we show that the training loss for a one-layer multi-head transformer converges linearly to a global minimum. Moreover, the transformer effectively learns to perform ridge regression over the basis functions. To our knowledge, this study is the first provable demonstration that transformers can learn contextual (i.e., template) information to generalize to both unseen examples and tasks when prompts contain only a small number of query-answer pairs.
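As a rough illustration of what "ridge regression over the basis functions" could look like, the sketch below fits noisy labels for the first $N$ inputs of a prompt and then predicts all $K$ inputs, with $N < m$ so that the labels alone do not determine the template. The function name `ridge_over_basis`, the penalty `lam`, the noise level, and the random feature matrix `Phi` are all assumptions made for illustration; the paper's exact regularizer and data model are not stated in this summary.

```python
import numpy as np

def ridge_over_basis(Phi, y_labeled, lam):
    """Ridge-regress noisy labels onto basis features, then predict every prompt input.

    Phi       : (K, m) array, basis-function values for all K inputs in the prompt.
    y_labeled : (N,) array, noisy labels for the first N inputs (N <= K).
    lam       : ridge penalty (hypothetical knob; the paper's value is not given here).
    """
    n_labeled = y_labeled.shape[0]
    m = Phi.shape[1]
    Phi_lab = Phi[:n_labeled]                      # features of the labeled examples
    gram = Phi_lab.T @ Phi_lab + lam * np.eye(m)   # regularized Gram matrix, invertible for lam > 0
    coef = np.linalg.solve(gram, Phi_lab.T @ y_labeled)
    return Phi @ coef                              # predictions for labeled and unlabeled inputs

rng = np.random.default_rng(0)
K, N, m = 50, 20, 30                               # N < m: labels alone cannot pin down the template
Phi = rng.standard_normal((K, m))                  # stand-in for the m basis functions at K inputs
true_coef = rng.standard_normal(m)                 # template = linear combination of the basis
y = Phi @ true_coef + 0.1 * rng.standard_normal(K) # Gaussian label noise
y_hat = ridge_over_basis(Phi, y[:N], lam=0.5)
print(y_hat.shape)
```

The penalty term is what keeps the problem well-posed when $N < m$: without it the Gram matrix would be singular and the coefficients would not be unique.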

Paper Structure

This paper contains 44 sections, 13 theorems, 106 equations, 5 figures, and 2 tables.

Key Result

Proposition 1

Suppose Assumptions asmp:lambda_dist and asmp:init hold and $H\geq N$. For any fixed $\beta>0$, if ${\bm{Q}}_h^{(0)}(i,j)\overset{i.i.d.}{\sim}\mathcal{N}(0,\beta^2)$, then Assumption asmp:gamma holds almost surely.
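The initialization in Proposition 1 can be sketched as follows, with every entry of each ${\bm{Q}}_h^{(0)}$ drawn i.i.d. from $\mathcal{N}(0,\beta^2)$ and the number of heads chosen so that $H\geq N$. The helper name `init_query_matrices`, the square `dim` x `dim` shape of each matrix, and the particular values of `H`, `N`, and `beta` are assumptions for illustration; the summary does not specify the matrix dimensions.

```python
import numpy as np

def init_query_matrices(num_heads, dim, beta, seed=0):
    """Draw every entry of each Q_h^(0) i.i.d. from N(0, beta^2).

    Returns an array of shape (num_heads, dim, dim); the square shape is an
    assumption of this sketch, not something stated in the source.
    """
    if beta <= 0:
        raise ValueError("beta must be positive")
    rng = np.random.default_rng(seed)
    return beta * rng.standard_normal((num_heads, dim, dim))

N = 8                      # number of labeled examples per prompt (illustrative)
H = 16                     # number of heads, chosen so that H >= N
Q0 = init_query_matrices(num_heads=H, dim=5, beta=0.1)
print(Q0.shape)
```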

Figures (5)

  • Figure 1: The structure of a one-layer transformer with multi-head softmax attention.
  • Figure 2: Training and inference losses of (a) 1-layer and (b) 4-layer transformers, which validate Theorem \ref{thm:inference_under}, as well as the transformer's contextual generalization to unseen examples and to unseen tasks.
  • Figure 3: The performance gap $\frac{1}{K}\left\| \widehat{{\bm{y}}}^\star-\widehat{{\bm{y}}}^{\text{best}} \right\|_2^2$ with different $N$ when $m=100$, which validates that the closer $N$ is to $m$, the better the transformer's prediction is.
  • Figure 4: Training losses of the 1-layer transformer with different numbers of attention heads $H$, where $H$ should be large enough to guarantee the convergence of the training loss, but setting $H$ too large leads to instability and slower convergence.
  • Figure 5: Training losses of a 4-layer transformer with different $H$, fixing wall-clock time to be $100$s. This experiment shows that unlike 1-layer transformers, deeper transformers don't require $H$ to be large to guarantee convergence of the loss.

Theorems & Definitions (25)

  • Proposition 1: Initialization of $\{{\bm{Q}}_h\}_{h=1}^H$
  • proof
  • Theorem 1: Training time convergence
  • proof
  • Theorem 2: Inference time performance
  • proof
  • Lemma 1: Softmax gradient
  • proof
  • Lemma 2: Smoothness of softmax
  • proof
  • ...and 15 more