How Do Nonlinear Transformers Learn and Generalize in In-Context Learning?

Hongkang Li, Meng Wang, Songtao Lu, Xiaodong Cui, Pin-Yu Chen

TL;DR

The paper addresses how to train nonlinear Transformers to perform in-context learning (ICL) on a family of binary classification tasks under distribution shifts. It introduces a theoretical framework with a single-head, one-layer Transformer using nonlinear self-attention and a ReLU MLP, trained via hinge loss on prompts built from context-example pairs, and derives generalization bounds for in-domain and out-of-domain tasks, plus pruning implications. Three core contributions are presented: (i) a quantitative learning analysis with polynomial sample complexity, (ii) a mechanism showing that self-attention concentrates on contexts sharing the query's in-domain-relevant (IDR) or out-of-domain-relevant (ODR) pattern and that the MLP reinforces the corresponding label embedding, and (iii) the first theory showing that magnitude-based pruning can preserve ICL while reducing inference cost. The findings are supported by numerical experiments that validate the predicted dependencies on prompt length, context-pattern information, and pruning, with practical implications for designing prompt-based learning and efficient Transformers.

Abstract

Transformer-based large language models have displayed impressive in-context learning (ICL) capabilities, where a pre-trained model can handle new tasks without fine-tuning by simply augmenting the query with some input-output examples from that task. Despite the empirical success, the mechanics of how to train a Transformer to achieve ICL and the corresponding ICL capacity are mostly elusive due to the technical challenges of analyzing the nonconvex training problems resulting from the nonlinear self-attention and nonlinear activation in Transformers. To the best of our knowledge, this paper provides the first theoretical analysis of the training dynamics of Transformers with nonlinear self-attention and nonlinear MLP, together with the ICL generalization capability of the resulting model. Focusing on a group of binary classification tasks, we train Transformers using data from a subset of these tasks and quantify the impact of various factors on the ICL generalization performance on the remaining unseen tasks with and without data distribution shifts. We also analyze how different components in the learned Transformers contribute to the ICL performance. Furthermore, we provide the first theoretical analysis of how model pruning affects ICL performance and prove that proper magnitude-based pruning can have a minimal impact on ICL while reducing inference costs. These theoretical findings are justified through numerical experiments.
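
The setup described above (a prompt of context-label pairs plus a query, one single-head softmax attention layer, a ReLU MLP, and hinge loss) can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the authors' implementation: the dimensions, the identity weight matrices, the toy data, the helper names (`build_prompt`, `forward`, `hinge_loss`), and the choice to attend over the context columns only are all hypothetical.

```python
import math

def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def build_prompt(contexts, labels, query, label_dim):
    """Stack each context input with its label embedding; the query gets a zero label."""
    columns = [x + y for x, y in zip(contexts, labels)]
    columns.append(query + [0.0] * label_dim)
    return columns

def forward(prompt, W_Q, W_K, W_V, W_O, a):
    """One attention layer followed by a ReLU MLP; returns (scalar output, attention weights)."""
    p_query = prompt[-1]
    context_cols = prompt[:-1]  # assumption: attention runs over context columns only
    q = matvec(W_Q, p_query)
    scores = [sum(k * qi for k, qi in zip(matvec(W_K, p), q)) for p in context_cols]
    attn = softmax(scores)
    mixed = [0.0] * len(W_V)
    for weight, p in zip(attn, context_cols):
        value = matvec(W_V, p)
        mixed = [m + weight * v for m, v in zip(mixed, value)]
    hidden = [max(0.0, h) for h in matvec(W_O, mixed)]  # ReLU
    return sum(ai * hi for ai, hi in zip(a, hidden)), attn

def hinge_loss(z, output):
    """Hinge loss for a label z in {+1, -1}."""
    return max(0.0, 1.0 - z * output)

# Toy example (all values hypothetical): 2-d inputs, 1-d label embedding.
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
prompt = build_prompt(
    contexts=[[1.0, 0.0], [0.0, 1.0]],
    labels=[[1.0], [-1.0]],
    query=[1.0, 0.0],
    label_dim=1,
)
W_O = [[0.0, 0.0, 1.0], [0.0, 0.0, -1.0]]  # two neurons reading the label dimension
a = [1.0, -1.0]
output, attn = forward(prompt, identity, identity, identity, W_O, a)
loss = hinge_loss(1, output)
```

In this toy run the first context shares the query's input pattern, so it receives the larger attention weight and the output takes the sign of that context's label, which loosely mirrors the mechanism the summary attributes to the trained model.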

Paper Structure

This paper contains 45 sections, 14 theorems, 256 equations, 7 figures, 2 tables, 1 algorithm.

Key Result

Theorem 3.3

(In-Domain Generalization) Suppose Condition cond: task holds. For any $\epsilon>0$, when (i) the number of neurons in ${\boldsymbol W}_O$ satisfies $m\geq \Omega(M_1^2\log M_1)$, (ii) the batch size satisfies $B>\Omega(\max\{\epsilon^{-2},M_1\}\cdot\log M_1)$, (iii) the lengths of training and testing contexts a… (iv) and the number of iterations satisfies … with step size $\eta\leq 1$ and $N=BT$ samples, then w…
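
To get a feel for how conditions (i) and (ii) scale, the sketch below evaluates the two expressions for sample values of $M_1$ and $\epsilon$. The $\Omega(\cdot)$ notation hides an unspecified constant, so the constant `c = 1.0`, the natural logarithm, and the function name `theorem_scalings` are assumptions made purely for illustration; the outputs show growth rates, not actual requirements.

```python
import math

def theorem_scalings(M1, eps, c=1.0):
    """Evaluate the scalings in conditions (i) and (ii) up to a hypothetical constant c.

    Returns (m_scale, B_scale), where
      m_scale = c * M1^2 * log(M1)               # neurons in W_O
      B_scale = c * max(eps^-2, M1) * log(M1)    # batch size
    """
    log_m1 = math.log(M1)
    m_scale = c * M1 ** 2 * log_m1
    B_scale = c * max(eps ** -2, M1) * log_m1
    return m_scale, B_scale

# Tighter target error: the eps^-2 term dominates the batch-size scaling.
m_tight, B_tight = theorem_scalings(M1=10, eps=0.1)

# Looser target error: the M1 term dominates instead.
m_loose, B_loose = theorem_scalings(M1=10, eps=0.5)
```

The neuron-count scaling depends only on $M_1$, while the batch-size scaling switches between the $\epsilon^{-2}$ and $M_1$ regimes depending on which is larger.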

Figures (7)

  • Figure 1: (A) Example of prompt embedding. $l=3$, $\alpha=2/3$. (B) The mechanism of a trained Transformer (\ref{eqn: transformer}) to implement ICL. Part I: The attention layer assigns the largest attention score (0.8) to ${\boldsymbol \mu}_1-0.3{\boldsymbol \nu}_5$, which has the same IDR pattern as the query. Then the weighted sum of input tokens is close to $({\boldsymbol \mu}_1^\top,{\boldsymbol q}^\top)^\top$ by the trained attention layer. Part II: The neurons in ${\boldsymbol W}_O{\boldsymbol W}_V$ with a large magnitude are aligned with $\bar{{\boldsymbol \mu}}$ and $\pm{\boldsymbol q}$ in the first $d_\mathcal{X}$ and the rest $d_\mathcal{Y}$ dimensions, respectively. Then the prediction is based on the part of $\pm{\boldsymbol q}$ that varies for different queries rather than the part of $\bar{{\boldsymbol \mu}}$ that is universal for all IDR patterns.
  • Figure 2: The properties of the trained model. (A) The average norm of ${\boldsymbol W}_Q{\boldsymbol p}_{query}$, ${\boldsymbol W}_K{\boldsymbol p}_i$, $[XDR({\boldsymbol p}_{query})^\top/\beta,\boldsymbol{0}^\top]\cdot{\boldsymbol W}_Q{\boldsymbol p}_{query}$, and $[XDR({\boldsymbol p}_i)^\top/\beta,\boldsymbol{0}^\top]{\boldsymbol W}_K{\boldsymbol p}_i$. (B) The attention weight summation on contexts with the same ODR pattern as the query and other contexts. (C) The magnitude of the first $d_\mathcal{X}$ dimensions of $5$ neurons in ${\boldsymbol W}_O{\boldsymbol W}_V$ and their angles to $\bar{{\boldsymbol \mu}}$ in $400$ epochs. (D) The magnitude of the rest $d_\mathcal{Y}$ dimensions of $10$ neurons in ${\boldsymbol W}_O{\boldsymbol W}_V$ and their angles to ${\boldsymbol q}$ in $400$ epochs. We choose $5$ neurons for $a_i>0$ and $5$ for $a_i<0$.
  • Figure 3: Out-of-domain ICL classification error on GPT-2 with (a) different $S_1$ (b) different $\alpha'$ for in-domain (id) and out-of-domain (ood) generalization.
  • Figure 4: Binary classification performance of using ICL, logistic regression (Logistic), SVM with Gaussian kernel (SVM Gau.), SVM with linear kernel (SVM Lin.), 1-nearest neighbor (1-NN), and 3-nearest neighbor (3-NN) with one-layer Transformer when (A) $\alpha'=0.8$ (B) $\alpha'=0.6$.
  • Figure 5: (A) Out-of-domain classification error (left y-axis for curves) with model pruning of the trained ${\boldsymbol W}_O$ using baseline (no pruning), random pruning, and magnitude-based pruning (Mag.-based), and the magnitude of each neuron of ${\boldsymbol W}_O$ (right y-axis for light blue bars) (B) Out-of-domain classification error when varying $\alpha'$. These two are implemented on a one-layer Transformer.
  • ...and 2 more figures
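
The pruning comparison in Figure 5 (magnitude-based versus random pruning of the neurons of ${\boldsymbol W}_O$) can be sketched as follows. This is a minimal illustration, not the paper's code: the toy matrix, the use of the Euclidean row norm as the "magnitude" of a neuron, the pruning fraction, and the function names are all assumptions.

```python
import math
import random

def neuron_magnitudes(W_O):
    """Euclidean norm of each row (neuron) of W_O."""
    return [math.sqrt(sum(w * w for w in row)) for row in W_O]

def zero_rows(W_O, pruned):
    """Return a copy of W_O with the rows in `pruned` set to zero."""
    return [[0.0] * len(row) if i in pruned else list(row)
            for i, row in enumerate(W_O)]

def magnitude_prune(W_O, prune_fraction):
    """Zero out the given fraction of neurons with the smallest magnitude."""
    norms = neuron_magnitudes(W_O)
    n_prune = int(prune_fraction * len(W_O))
    order = sorted(range(len(W_O)), key=lambda i: norms[i])
    pruned = set(order[:n_prune])
    return zero_rows(W_O, pruned), pruned

def random_prune(W_O, prune_fraction, seed=0):
    """Baseline: zero out the same number of neurons chosen uniformly at random."""
    rng = random.Random(seed)
    n_prune = int(prune_fraction * len(W_O))
    pruned = set(rng.sample(range(len(W_O)), n_prune))
    return zero_rows(W_O, pruned), pruned

# Toy W_O (hypothetical values): two large-magnitude and two small-magnitude neurons.
W_O = [[3.0, 4.0], [0.1, 0.0], [0.0, 0.2], [1.0, 1.0]]
W_mag, pruned_mag = magnitude_prune(W_O, prune_fraction=0.5)
W_rand, pruned_rand = random_prune(W_O, prune_fraction=0.5)
```

Magnitude-based pruning removes only the two small neurons here, whereas random pruning may discard a large one; that difference is the intuition behind comparing the two strategies.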

Theorems & Definitions (36)

  • Definition 3.1
  • Theorem 3.3
  • Theorem 3.4
  • Remark 3.5
  • Remark 3.6
  • Theorem 3.7
  • Remark 3.8
  • Proposition 4.1
  • Remark 4.2
  • Corollary 4.3
  • ...and 26 more