In-Context Learning with Representations: Contextual Generalization of Trained Transformers
Tong Yang, Yu Huang, Yingbin Liang, Yuejie Chi
TL;DR
This work provides a first-principles analysis of how transformers can acquire contextual (template) information from prompts to generalize to unseen examples and tasks without parameter updates. The authors model ICL as learning a linear combination of fixed basis representations and train a one-layer, multi-head transformer with gradient descent. They prove that the training loss converges linearly to a global minimum and show that, at inference, the trained model performs ridge regression over the basis functions. The results show that contextual generalization emerges even with few, noisy prompt labels, and that the representation dimension relative to the prompt size critically shapes performance. The analysis relaxes several strong assumptions in prior theory and underscores the essential role of multi-head attention in enabling this contextual memorization and generalization. The accompanying experiments on synthetic data corroborate the theory and illustrate practical implications for prompt design and architecture choice.
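As a rough illustration of what "ridge regression over the basis functions" means, the estimator can be written as below. The notation ($\phi$, $\Phi$, $\mathbf{y}$, and the regularization weight $\tau$) is our own assumption, and the paper's exact scaling of the regularizer may differ. Here $f_1, \dots, f_m$ are the basis functions and $(x_1, y_1), \dots, (x_N, y_N)$ are the labeled pairs in a prompt.

```latex
\phi(x) = \big(f_1(x), \dots, f_m(x)\big)^\top \in \mathbb{R}^m, \qquad
\Phi = \begin{bmatrix} \phi(x_1)^\top \\ \vdots \\ \phi(x_N)^\top \end{bmatrix} \in \mathbb{R}^{N \times m}, \qquad
\mathbf{y} = (y_1, \dots, y_N)^\top

\hat{\boldsymbol\lambda}
  = \arg\min_{\boldsymbol\lambda \in \mathbb{R}^m} \|\mathbf{y} - \Phi\boldsymbol\lambda\|_2^2 + \tau\|\boldsymbol\lambda\|_2^2
  = (\Phi^\top\Phi + \tau I_m)^{-1}\Phi^\top\mathbf{y}
  = \Phi^\top(\Phi\Phi^\top + \tau I_N)^{-1}\mathbf{y}

\hat{y}(x) = \phi(x)^\top \hat{\boldsymbol\lambda}
```

When $N < m$, the matrix $\Phi^\top\Phi$ is singular, so $\tau > 0$ is what makes the estimate unique. The last form needs only an $N \times N$ inverse, which suggests why the ratio of representation dimension to prompt size matters.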
Abstract
In-context learning (ICL) refers to a remarkable capability of pretrained large language models: they can learn a new task from a few examples given at inference time. However, the theory of ICL remains largely under-explored. In particular, it is unclear whether transformers can be trained to generalize to unseen examples in a prompt, which requires the model to acquire contextual knowledge of the prompt. This paper investigates the training dynamics of transformers trained by gradient descent through the lens of non-linear regression tasks. Here, contextual generalization can be attained by learning the template function of each task in context, where all template functions lie in a linear space spanned by $m$ basis functions. We analyze the training dynamics of one-layer multi-head transformers that predict the labels of unlabeled inputs in context, given partially labeled prompts in which the labels contain Gaussian noise and the number of examples in each prompt is not sufficient to determine the template. Under mild assumptions, we show that the training loss of a one-layer multi-head transformer converges linearly to a global minimum. Moreover, the transformer effectively learns to perform ridge regression over the basis functions. To our knowledge, this study is the first provable demonstration that transformers can learn contextual (i.e., template) information to generalize to both unseen examples and unseen tasks when prompts contain only a small number of query-answer pairs.
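The following is a minimal simulation of this setting. It rests on assumptions of our own: a hypothetical cosine basis $f_k(x) = \cos(kx)$, Gaussian task coefficients, and the illustrative sizes m = 8 and N = 5. The ridge regression is computed explicitly rather than by a trained transformer. The code draws a template from the span of the m basis functions and labels only N < m prompt inputs with Gaussian noise, so the template cannot be identified from the prompt alone. It then predicts the labels of unlabeled queries.

```python
import math
import random


def basis(x: float, m: int) -> list[float]:
    """Hypothetical basis for illustration: f_k(x) = cos(k * x), k = 0..m-1."""
    return [math.cos(k * x) for k in range(m)]


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def solve(a: list[list[float]], b: list[float]) -> list[float]:
    """Solve a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    # Work on an augmented copy so the inputs are not mutated.
    aug = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(col + 1, n):
            factor = aug[r][col] / aug[col][col]
            for c in range(col, n + 1):
                aug[r][c] -= factor * aug[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        tail = sum(aug[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (aug[r][n] - tail) / aug[r][r]
    return x


def ridge_predict(xs, ys, queries, m, tau):
    """Ridge regression over m basis features, solved in dual (N x N) form."""
    feats = [basis(x, m) for x in xs]
    n = len(xs)
    # Regularized Gram matrix K + tau * I with K[i][j] = <phi(x_i), phi(x_j)>.
    gram = [[dot(feats[i], feats[j]) + (tau if i == j else 0.0)
             for j in range(n)] for i in range(n)]
    alpha = solve(gram, ys)
    # phi(q)^T Phi^T alpha = sum_i alpha_i * <phi(q), phi(x_i)>.
    return [sum(a * dot(basis(q, m), f) for a, f in zip(alpha, feats))
            for q in queries]


rng = random.Random(0)
m, n_labeled, sigma, tau = 8, 5, 0.1, 0.1  # N < m: prompt underdetermines the template
lam = [rng.gauss(0.0, 1.0) for _ in range(m)]  # task-specific template coefficients


def template(x: float) -> float:
    return dot(lam, basis(x, m))


xs = [rng.uniform(-math.pi, math.pi) for _ in range(n_labeled)]
ys = [template(x) + rng.gauss(0.0, sigma) for x in xs]  # noisy prompt labels
queries = [rng.uniform(-math.pi, math.pi) for _ in range(3)]  # unlabeled inputs
preds = ridge_predict(xs, ys, queries, m, tau)
print([round(p, 3) for p in preds])
```

Solving the $N \times N$ dual system instead of the $m \times m$ primal system is a convenience for the $N < m$ regime. For any $\tau > 0$, both systems give the same predictions.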
