In-Context Convergence of Transformers
Yu Huang, Yuan Cheng, Yingbin Liang
TL;DR
This work develops a theoretical understanding of in-context learning in softmax-attention transformers by analyzing a one-layer model trained via gradient descent to in-context learn linear function classes. It introduces a structured data model in which tokens are drawn from a set of feature vectors in either a balanced or an imbalanced fashion, and derives two convergence results. For balanced features, it gives a finite-time guarantee of near-zero prediction error over two training phases. For imbalanced features, it shows a stage-wise convergence: queries of the dominant feature converge after one phase, and queries of under-represented features converge after four phases. The phases are determined by the competing strengths of the evolving bilinear attention weights $A_k^{(t)}$ and $B_{k,n}^{(t)}$. The study shows that softmax attention trained by gradient descent can concentrate on the relevant feature and make accurate predictions on unseen prompts, shedding light on the in-context learning mechanism of nonlinear transformers. The results bear on when and how in-context knowledge can be learned and transferred in practical transformer architectures, and the analytic techniques may extend to broader models and tasks.
Abstract
Transformers have recently revolutionized many domains in modern machine learning, and one salient discovery is their remarkable in-context learning capability: models can solve an unseen task by utilizing task-specific prompts without further fine-tuning of their parameters. This has also inspired recent theoretical studies aiming to understand the in-context learning mechanism of transformers, which, however, have focused only on linear transformers. In this work, we take the first step toward studying the learning dynamics of a one-layer transformer with softmax attention trained via gradient descent to in-context learn linear function classes. We consider a structured data model, where each token is randomly sampled from a set of feature vectors in either a balanced or an imbalanced fashion. For data with balanced features, we establish a finite-time convergence guarantee with near-zero prediction error by navigating our analysis over two phases of the training dynamics of the attention map. More notably, for data with imbalanced features, we show that the learning dynamics follow a stage-wise convergence process: the transformer first converges to a near-zero prediction error for the query tokens of dominant features, and later converges to a near-zero prediction error for the query tokens of under-represented features, via one and four training phases, respectively. Our proof features new techniques for analyzing the competing strengths of two types of attention weights, whose changes determine the different training phases.
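To make the setup concrete, the following is a minimal Python sketch of the data model and a one-layer softmax-attention predictor trained by gradient descent. It rests on simplifying assumptions that are ours, not the paper's: the feature vectors are the standard basis of $\mathbb{R}^K$ (so each token is represented by its feature index), a fresh task vector $w \sim N(0, I_K)$ is drawn per prompt, the prediction is the attention-weighted average of the in-context labels with scores $x_i^\top Q x_q$, the query token is excluded from the attention, and only a hypothetical key-query matrix `q_mat` is trained on the squared loss. The functions `sample_prompt`, `predict`, `train`, and `eval_error` and all hyperparameters are illustrative.

```python
import math
import random


def softmax(scores):
    # Numerically stable softmax over a list of scores.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def sample_prompt(rng, probs, n_tokens):
    """Sample one prompt under the ASSUMED data model.

    Features are the standard basis vectors e_0..e_{K-1}, so a token is its
    feature index k and its label is y = <w, e_k> = w[k], with w ~ N(0, I_K).
    """
    k_feats = len(probs)
    w = [rng.gauss(0.0, 1.0) for _ in range(k_feats)]
    tokens = rng.choices(range(k_feats), weights=probs, k=n_tokens)
    query = rng.choices(range(k_feats), weights=probs, k=1)[0]
    return tokens, [w[k] for k in tokens], query, w[query]


def predict(q_mat, tokens, labels, query):
    # For basis features the score x_i^T Q x_q reduces to Q[k_i][k_q].
    attn = softmax([q_mat[k][query] for k in tokens])
    return sum(a * y for a, y in zip(attn, labels)), attn


def train(probs, n_tokens=20, steps=300, batch=64, lr=1.0, seed=0):
    """Full-matrix gradient descent on L = 0.5 * (y_hat - y_q)^2, batch-averaged."""
    rng = random.Random(seed)
    k_feats = len(probs)
    q_mat = [[0.0] * k_feats for _ in range(k_feats)]  # zero init: uniform attention
    for _ in range(steps):
        grad = [[0.0] * k_feats for _ in range(k_feats)]
        for _ in range(batch):
            tokens, labels, query, y_q = sample_prompt(rng, probs, n_tokens)
            y_hat, attn = predict(q_mat, tokens, labels, query)
            # Chain rule: dL/ds_i = (y_hat - y_q) * a_i * (y_i - y_hat).
            for k, a, y in zip(tokens, attn, labels):
                grad[k][query] += (y_hat - y_q) * a * (y - y_hat) / batch
        for i in range(k_feats):
            for j in range(k_feats):
                q_mat[i][j] -= lr * grad[i][j]
    return q_mat


def eval_error(q_mat, probs, n_tokens=20, n_prompts=500, seed=1):
    """Mean squared prediction error on fresh prompts, split by query feature."""
    rng = random.Random(seed)
    k_feats = len(probs)
    sq_err, counts = [0.0] * k_feats, [0] * k_feats
    for _ in range(n_prompts):
        tokens, labels, query, y_q = sample_prompt(rng, probs, n_tokens)
        y_hat, _ = predict(q_mat, tokens, labels, query)
        sq_err[query] += (y_hat - y_q) ** 2
        counts[query] += 1
    return [e / c if c else float("nan") for e, c in zip(sq_err, counts)]


if __name__ == "__main__":
    # Compare a balanced and an imbalanced feature distribution.
    for probs in ([0.25] * 4, [0.7, 0.1, 0.1, 0.1]):
        q_mat = train(probs)
        print(probs, [round(e, 4) for e in eval_error(q_mat, probs)])
```

Under these assumptions, the gradient on each diagonal entry `q_mat[k][k]` is never positive, because a token sharing the query's feature carries the query's own label, so gradient descent steadily shifts attention toward same-feature tokens. With an imbalanced `probs` such as `[0.7, 0.1, 0.1, 0.1]`, one would expect the error on rare-feature queries to lag behind that on the dominant feature after the same number of steps, loosely in the spirit of the stage-wise behavior described above; the toy does not reproduce the paper's phase analysis.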
