On the Training Convergence of Transformers for In-Context Classification of Gaussian Mixtures

Wei Shen, Ruida Zhou, Jing Yang, Cong Shen

TL;DR

This work provides a rigorous analysis of the training dynamics of transformers for in-context classification of Gaussian mixtures. It proves that a single-layer transformer trained by gradient descent converges at a linear rate to a global minimizer $W^*=2(\Lambda^{-1}+G)$ with $\|G\|_{\max}=O(1/N)$, where $\Lambda$ is the covariance matrix of the Gaussian components and $N$ is the training prompt length. It also derives inference-error bounds that decay as $O(1/N)$ in the training prompt length and as $O(1/\sqrt{M})$ in the testing prompt length $M$. Extending to multi-class problems with $c$ classes, it shows $W^*=c(\Lambda^{-1}+G)$ with $\|G\|_{\max}=O(c/N)$ and establishes an $O(c^2N^{-1}+c^{3/2}M^{-1/2})$ inference-error bound, linking the learned transformer to approximate linear discriminant and softmax classifiers. Experiments corroborate the theory, reveal differences in robustness between 1-layer and 3-layer models under distribution shifts, and demonstrate superior ICL performance over traditional baselines on Gaussian mixtures, highlighting the practical effect of prompt length and the number of classes.
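One way to see where the factor 2 in $W^*=2(\Lambda^{-1}+G)$ can come from: assume, purely for illustration, balanced labels $y\in\{-1,+1\}$ and $x\mid y\sim\mathsf{N}(y\mu,\Lambda)$. Then the Bayes posterior is $\mathbb{P}(y=1\mid x)=\sigma(2\mu^\top\Lambda^{-1}x)$, while the prompt average $\frac{1}{M}\sum_i y_i x_i$ is an unbiased estimate of $\mu$. A bilinear read-out $x_{\text{query}}^\top W\big(\frac{1}{M}\sum_i y_i x_i\big)$ with $W=2\Lambda^{-1}$ (the $G=0$ limit) therefore approaches the true posterior logit. The minimal NumPy sketch below checks this numerically and shows the error shrinking roughly like $1/\sqrt{M}$ as the test prompt grows. The data model, the read-out used as a stand-in for the trained transformer, the hypothetical functions `sample_prompt`, `icl_prob`, `bayes_prob`, and all constants are our assumptions, not the paper's exact parameterization or code.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sample_prompt(mu, lam, n, rng):
    """Hypothetical data model: balanced y in {-1,+1}, x | y ~ N(y * mu, lam)."""
    y = rng.choice([-1.0, 1.0], size=n)
    chol = np.linalg.cholesky(lam)  # lam = chol @ chol.T
    x = y[:, None] * mu + rng.standard_normal((n, mu.size)) @ chol.T
    return x, y

def icl_prob(x_ctx, y_ctx, x_query, W):
    """Stand-in read-out: sigmoid(x_query^T W mean_i(y_i x_i))."""
    mu_hat = (y_ctx[:, None] * x_ctx).mean(axis=0)  # unbiased estimate of mu
    return sigmoid(x_query @ W @ mu_hat)

def bayes_prob(x_query, mu, lam):
    """Ground-truth posterior under the assumed model: sigmoid(2 mu^T lam^{-1} x)."""
    return sigmoid(2.0 * x_query @ np.linalg.solve(lam, mu))

rng = np.random.default_rng(0)
mu = np.array([1.0, -0.5, 0.25])
lam = np.diag([1.0, 2.0, 0.5])
W = 2.0 * np.linalg.inv(lam)  # G = 0 limit of W* = 2(Lambda^{-1} + G)
x_query = np.array([0.3, 0.1, -0.2])
p_true = bayes_prob(x_query, mu, lam)

# Mean absolute error of the in-context prediction for growing test prompt length M.
errors = {}
for M in [100, 400, 1600, 6400]:
    errs = []
    for _ in range(200):
        x_ctx, y_ctx = sample_prompt(mu, lam, M, rng)
        errs.append(abs(icl_prob(x_ctx, y_ctx, x_query, W) - p_true))
    errors[M] = float(np.mean(errs))
print(errors)  # each 4x increase in M should roughly halve the error
```

Fixing $W$ at its $G=0$ limit isolates the test-prompt term of the bound; the $O(1/N)$ term stems from the training-side perturbation $G$, which this sketch does not model.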

Abstract

Although transformers have demonstrated impressive capabilities for in-context learning (ICL) in practice, theoretical understanding of the underlying mechanism that allows transformers to perform ICL is still in its infancy. This work aims to theoretically study the training dynamics of transformers for in-context classification tasks. We demonstrate that, for in-context classification of Gaussian mixtures under certain assumptions, a single-layer transformer trained via gradient descent converges to a globally optimal model at a linear rate. We further quantify the impact of the training and testing prompt lengths on the ICL inference error of the trained transformer. We show that when the lengths of training and testing prompts are sufficiently large, the prediction of the trained transformer approaches the ground truth distribution of the labels. Experimental results corroborate the theoretical findings.

Paper Structure

This paper contains 35 sections, 20 theorems, 170 equations, 4 figures.

Key Result

Theorem 3.3

Under the binary-classification training data distribution assumption, the following statements hold.

Figures (4)

  • Figure 1: '1-layer': single-layer transformer defined in the multi-class classification section; '3-layer': 3-layer transformer with softmax attention. $N$: training prompt length. $c$: number of Gaussian mixture components (classes).
  • Figure 2: We generate $\Lambda_i=\text{diag}(\lambda_{i1}, \dots, \lambda_{id})$, $i\in\{0,1,2,3\}$, where $\lambda_{ij}=|\hat{\lambda}_{ij}|$ and $\hat{\lambda}_{ij}\stackrel{\mathrm{i.i.d.}}{\sim} \mathsf{N}(3,1)$. All models are trained with prompt length $N=100$ and tested with prompts satisfying the multi-class test prompt distribution assumption with $\Lambda_0$; $c=3$. (a) 'same norm': pre-training data are sampled according to the multi-class training data distribution assumption with $\Lambda_0$. 'different norms': for each $\tau$, $k$ is drawn with $\mathbb{P}(k=j)=1/10$ for $j=0,1,\ldots,9$, and $\mu_{\tau, i}\sim\mathsf{N}(k, I_d)$. (b) 'same covariance': pre-training data are sampled according to the multi-class training data distribution assumption with the fixed $\Lambda_0$. 'different covariances': additional $\Lambda_1, \Lambda_2, \Lambda_3$ are sampled, and pre-training data are generated according to the same assumption with $\Lambda_0, \Lambda_1, \Lambda_2, \Lambda_3$.
  • Figure 3: '1-layer, sparse': single-layer transformer defined in the multi-class classification section; '1-layer, full': single-layer transformer with full parameters; '3-layer': 3-layer transformer with softmax attention; 'softmax': softmax regression; 'LDA': linear discriminant analysis. All three transformers are trained with prompt length $N=100$.
  • Figure 4: Inference errors of single-layer transformers. (a): Models trained on different training prompt lengths $N$ on classification tasks involving $c=10$ classes. (b): Models trained on different classification tasks involving $c$ classes with a fixed training prompt length $N=80$. (c): Relationship between the inference error and the test prompt length $M$ in log-log axes. Training prompt length $N=2000$ and number of classes $c=6$. (d): Relationship between the inference error and the training prompt length $N$ in log-log axes. Test prompt length $M=2000$ and number of classes $c=6$.
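Figures 1, 2 and 4 vary the number of classes $c$. Under the same kind of illustrative assumptions, the multi-class minimizer $W^*=c(\Lambda^{-1}+G)$ can be motivated as follows: with labels uniform over $c$ classes, the prompt average $m_k=\frac{1}{M}\sum_i \mathbb{1}[y_i=k]\,x_i$ estimates $\mu_k/c$, so the factor $c$ rescales it back to $\mu_k$, and the logit $x^\top W m_k$ approaches the linear-discriminant score $\mu_k^\top\Lambda^{-1}x$. The sketch below checks that a softmax over these logits matches the Bayes posterior. Everything in it is an assumption for illustration (uniform labels, $x\mid y=k\sim\mathsf{N}(\mu_k,\Lambda)$, means of equal $\Lambda^{-1}$-norm so the term $-\tfrac12\mu_k^\top\Lambda^{-1}\mu_k$ cancels in the softmax, and $G=0$), not the paper's exact model or code.

```python
import numpy as np

def softmax(z):
    z = z - z.max()  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(1)
d, c, M = 4, 3, 50000
lam = np.diag([1.0, 2.0, 0.5, 1.5])
lam_inv = np.linalg.inv(lam)

# Assumption: class means rescaled to the same Lambda^{-1}-norm r.
r = 1.5
U = rng.standard_normal((c, d))
norms = np.sqrt(np.einsum("kd,de,ke->k", U, lam_inv, U))
mus = r * U / norms[:, None]

# Prompt: uniform labels, x | y=k ~ N(mu_k, lam).
y = rng.integers(0, c, size=M)
chol = np.linalg.cholesky(lam)
x = mus[y] + rng.standard_normal((M, d)) @ chol.T

# Read-out: m_k = (1/M) sum_i 1[y_i = k] x_i, which is approximately mu_k / c.
onehot = np.eye(c)[y]      # shape (M, c)
m = onehot.T @ x / M       # shape (c, d)
W = c * lam_inv            # G = 0 limit of W* = c(Lambda^{-1} + G)

x_query = np.array([0.5, -0.3, 0.2, 0.1])
p_icl = softmax(m @ W @ x_query)          # in-context class probabilities
p_bayes = softmax(mus @ lam_inv @ x_query)  # Bayes posterior (equal-norm means)
print(p_icl, p_bayes)
```

The equal-norm condition is what makes the Bayes posterior a softmax of purely linear scores here; without it, a quadratic offset per class would remain that this read-out does not capture.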

Theorems & Definitions (40)

  • Definition 3.1
  • Theorem 3.3
  • Corollary 3.4
  • Theorem 3.6
  • Remark 3.7
  • Remark 3.8
  • Definition 4.1
  • Theorem 4.3
  • Theorem 4.5
  • Lemma 3.1: karimi2016linear
  • ...and 30 more
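Lemma 3.1 cites karimi2016linear, which, judging by the key, is presumably the result of Karimi, Nutini, and Schmidt on linear convergence of gradient descent under the Polyak-Łojasiewicz (PL) condition, the usual route to the "linear rate" claim in the abstract. For reference, the standard form of that result is written out below; the paper's Lemma 3.1 may state it with different constants or conditions. Here $\mu$ denotes the PL constant, unrelated to the class means $\mu_{\tau,i}$ in the figure captions.

```latex
% Assumptions: f is L-smooth with minimum value f^*, and satisfies the PL inequality
\tfrac{1}{2}\,\|\nabla f(x)\|^2 \;\ge\; \mu\,\bigl(f(x)-f^*\bigr) \quad \text{for all } x.
% One step of gradient descent x_{k+1} = x_k - \tfrac{1}{L}\nabla f(x_k), by L-smoothness and then PL:
f(x_{k+1}) \;\le\; f(x_k) - \tfrac{1}{2L}\,\|\nabla f(x_k)\|^2 \;\le\; f(x_k) - \tfrac{\mu}{L}\,\bigl(f(x_k)-f^*\bigr).
% Subtracting f^* and unrolling gives the linear (geometric) rate:
f(x_k) - f^* \;\le\; \Bigl(1-\tfrac{\mu}{L}\Bigr)^{k}\,\bigl(f(x_0)-f^*\bigr).
```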