On the Training Convergence of Transformers for In-Context Classification of Gaussian Mixtures
Wei Shen, Ruida Zhou, Jing Yang, Cong Shen
TL;DR
This work provides a rigorous analysis of the training dynamics of transformers for in-context classification of Gaussian mixtures. It proves that a single-layer transformer trained by gradient descent converges at a linear rate to a global minimizer $W^*=2(\Lambda^{-1}+G)$ with $\|G\|_{\max}=O(1/N)$. Here $\Lambda$ is the shared covariance matrix of the mixture components and $N$ is the training prompt length. It then derives inference-error bounds that decay as $O(1/N)$ in the training prompt length $N$ and as $O(1/\sqrt{M})$ in the testing prompt length $M$. Extending the analysis to multi-class problems with $c$ classes, it shows $W^*=c(\Lambda^{-1}+G)$ with $\|G\|_{\max}=O(c/N)$ and establishes an $O(c^2N^{-1}+c^{3/2}M^{-1/2})$ inference-error bound, linking the learned transformer to approximate linear-discriminant and softmax classifiers. Experiments corroborate the theory and reveal differences in robustness between 1-layer and 3-layer models under distribution shifts. They also show that the trained transformers outperform traditional baselines on in-context classification of Gaussian mixtures, highlighting how prompt length and the number of classes affect performance in practice.
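A short Bayes computation suggests why the leading term of the binary minimizer is $2\Lambda^{-1}$. It assumes binary labels $y\in\{-1,+1\}$ with equal priors and class conditionals $x\mid y\sim\mathcal{N}(y\mu,\Lambda)$, with $\Lambda$ shared across tasks. This model is our illustrative assumption, not a restatement of the paper's exact setup. The estimate $\hat\mu$ below is likewise a hypothetical plug-in, not the paper's transformer.

```latex
\log\frac{p(x \mid y=+1)}{p(x \mid y=-1)}
  = \tfrac{1}{2}(x+\mu)^\top\Lambda^{-1}(x+\mu) - \tfrac{1}{2}(x-\mu)^\top\Lambda^{-1}(x-\mu)
  = 2\,\mu^\top\Lambda^{-1}x,
\qquad
P(y=+1 \mid x) = \sigma\!\left(2\,\mu^\top\Lambda^{-1}x\right), \quad \sigma(t)=\frac{1}{1+e^{-t}}.

% Replace \mu by the in-context estimate \hat\mu = \frac{1}{N}\sum_{i=1}^{N} y_i x_i:
\hat P(y=+1 \mid x) = \sigma\!\left(x^\top W\,\hat\mu\right), \qquad W = 2\Lambda^{-1}.
```

Under this assumed model, $\hat\mu\to\mu$ as the prompt grows, so the plug-in predictor approaches the Bayes posterior. The heuristic says nothing about the $O(1/N)$ correction $G$ in $W^*$.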
Abstract
Although transformers have demonstrated impressive capabilities for in-context learning (ICL) in practice, theoretical understanding of the underlying mechanism that allows transformers to perform ICL is still in its infancy. This work aims to theoretically study the training dynamics of transformers for in-context classification tasks. We demonstrate that, for in-context classification of Gaussian mixtures under certain assumptions, a single-layer transformer trained via gradient descent converges to a globally optimal model at a linear rate. We further quantify the impact of the training and testing prompt lengths on the ICL inference error of the trained transformer. We show that when the lengths of training and testing prompts are sufficiently large, the prediction of the trained transformer approaches the ground truth distribution of the labels. Experimental results corroborate the theoretical findings.
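The claim that predictions approach the ground-truth label distribution as prompts lengthen can be illustrated with a toy simulation. This is a minimal sketch under stated assumptions, not the authors' model or experiments. Labels are uniform on $\{-1,+1\}$ with $x\mid y\sim\mathcal{N}(y\mu,\Lambda)$. $\Lambda$ is fixed across tasks while $\mu$ is redrawn per task. The readout $\sigma(x_q^\top W\hat\mu)$ uses $W=2\Lambda^{-1}$, with the correction $G$ omitted. The helpers `sample_prompt`, `plug_in_prediction`, and `mean_abs_error` are hypothetical names chosen for illustration.

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def sample_prompt(rng, mu, cov, n):
    """Draw n pairs with y uniform on {-1, +1} and x | y ~ N(y * mu, cov)."""
    y = rng.choice([-1.0, 1.0], size=n)
    noise = rng.multivariate_normal(np.zeros(mu.shape[0]), cov, size=n)
    return y[:, None] * mu[None, :] + noise, y


def plug_in_prediction(x_ctx, y_ctx, x_query, W):
    """Hypothetical readout sigma(x_query^T W mu_hat), mu_hat = mean of y_i * x_i."""
    mu_hat = (y_ctx[:, None] * x_ctx).mean(axis=0)
    return sigmoid(x_query @ W @ mu_hat)


def mean_abs_error(prompt_len, trials=200, d=4, seed=0):
    """Average |prediction - Bayes posterior| over random task means mu (cov fixed)."""
    rng = np.random.default_rng(seed)
    a = rng.normal(size=(d, d))
    cov = a @ a.T / d + np.eye(d)   # assumed shared SPD covariance (Lambda)
    W = 2.0 * np.linalg.inv(cov)     # leading term of W*; G is omitted
    errs = []
    for _ in range(trials):
        mu = rng.normal(size=d) / np.sqrt(d)   # per-task class mean
        x_ctx, y_ctx = sample_prompt(rng, mu, cov, prompt_len)
        x_q, _ = sample_prompt(rng, mu, cov, 1)
        # Bayes posterior P(y=+1 | x_q) = sigma(2 mu^T Lambda^{-1} x_q)
        truth = sigmoid(2.0 * mu @ np.linalg.solve(cov, x_q[0]))
        errs.append(abs(plug_in_prediction(x_ctx, y_ctx, x_q[0], W) - truth))
    return float(np.mean(errs))


for n in (10, 100, 1000, 10000):
    print(f"prompt length {n:>6}: mean |error| = {mean_abs_error(n):.4f}")
```

Because $\hat\mu$ is a sample mean, the printed error in this toy setting should shrink roughly like $1/\sqrt{\text{prompt length}}$. That matches the rate of the $O(1/\sqrt{M})$ testing-prompt term, although the sketch does not model training at all.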
