How Do Transformers Learn In-Context Beyond Simple Functions? A Case Study on Learning with Representations
Tianyu Guo, Wei Hu, Song Mei, Huan Wang, Caiming Xiong, Silvio Savarese, Yu Bai
TL;DR
This work studies in-context learning (ICL) when task labels depend on inputs through a fixed representation $\Phi^\star$ followed by a linear readout that varies across instances. This setting is closer to practice than the simple function classes studied previously. The paper gives a constructive theory showing that decoder-based transformers of mild depth can implement in-context ridge regression on the representations. It validates these ideas empirically on synthetic data, observing a clear division in which lower layers compute $\Phi^\star(\mathbf{x})$ and upper layers perform linear ICL. The paper also develops probing and pasting techniques that reveal several mechanisms: the copying of inputs and representations, the upper module's ability to carry out linear ICL on its own, and representation selection in mixture-representation scenarios. These results offer mechanistic insight into how transformers could realize ICL in more complex, representation-based tasks, and they lay groundwork for extending the analysis to real-world representations. The findings may also inform prompt and architecture designs that separate representation learning from in-context adaptation, which could improve the robustness and interpretability of ICL in large language models.
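The two-stage algorithm can be sketched as follows. This is a minimal sketch and makes several assumptions. Labels take the form $y = \langle \mathbf{w}, \Phi^\star(\mathbf{x}) \rangle$, with $\Phi^\star$ fixed across prompts and $\mathbf{w}$ redrawn for every prompt. The stand-in $\Phi^\star$ is a single random ReLU layer. The helper names `make_representation`, `ridge_fit`, and `icl_predict` are hypothetical and are not the paper's code. The sketch only illustrates what the lower layers (the feature map) and the upper layers (ridge regression on the features) are meant to compute. It uses pure Python so that it has no dependencies.

```python
import random


def make_representation(d_in, d_rep, seed):
    """Hypothetical stand-in for the fixed Phi*: x -> relu(W x) with random W."""
    rng = random.Random(seed)
    W = [[rng.gauss(0.0, d_in ** -0.5) for _ in range(d_in)] for _ in range(d_rep)]
    return lambda x: [max(0.0, sum(a * b for a, b in zip(row, x))) for row in W]


def ridge_fit(feats, ys, lam):
    """Solve (F^T F + lam I) w = F^T y by Gauss-Jordan elimination with pivoting."""
    D = len(feats[0])
    # Augmented matrix [F^T F + lam I | F^T y].
    M = [[sum(f[i] * f[j] for f in feats) + (lam if i == j else 0.0) for j in range(D)]
         + [sum(f[i] * y for f, y in zip(feats, ys))] for i in range(D)]
    for c in range(D):
        p = max(range(c, D), key=lambda r: abs(M[r][c]))  # partial pivoting
        M[c], M[p] = M[p], M[c]
        for r in range(D):
            if r != c:
                factor = M[r][c] / M[c][c]
                M[r] = [a - factor * b for a, b in zip(M[r], M[c])]
    return [M[i][D] / M[i][i] for i in range(D)]


def icl_predict(phi, xs, ys, x_query, lam):
    """Stage 1: map the prompt through phi. Stage 2: ridge regression on the features."""
    feats = [phi(x) for x in xs]
    w_hat = ridge_fit(feats, ys, lam)
    return sum(a * b for a, b in zip(w_hat, phi(x_query)))


# Usage: one prompt whose readout w is freshly drawn on top of the fixed Phi*.
rng = random.Random(1)
d_in, d_rep, n = 4, 6, 60
phi_star = make_representation(d_in, d_rep, seed=0)
w_true = [rng.gauss(0.0, 1.0) for _ in range(d_rep)]
xs = [[rng.gauss(0.0, 1.0) for _ in range(d_in)] for _ in range(n)]
ys = [sum(a * b for a, b in zip(w_true, phi_star(x))) for x in xs]  # noiseless labels
x_query = [rng.gauss(0.0, 1.0) for _ in range(d_in)]
pred = icl_predict(phi_star, xs, ys, x_query, lam=1e-6)
target = sum(a * b for a, b in zip(w_true, phi_star(x_query)))
print(pred, target)
```

The labels here are noiseless, so a tiny `lam` essentially recovers least squares. Suppose instead the labels carry Gaussian noise of variance $\sigma^2$ and $\mathbf{w}$ has a $\mathcal{N}(0, \tau^2 I)$ prior. In that case the posterior-mean predictor is ridge regression with `lam` $= \sigma^2/\tau^2$.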
Abstract
While large language models based on the transformer architecture have demonstrated remarkable in-context learning (ICL) capabilities, our understanding of these capabilities is still at an early stage. Existing theory and mechanistic understanding focus mostly on simple scenarios such as learning simple function classes. This paper takes initial steps toward understanding ICL in more complex scenarios by studying learning with representations. Concretely, we construct synthetic in-context learning problems with a compositional structure. In these problems, the label depends on the input through a possibly complex but fixed representation function, composed with a linear function that differs in each instance. By construction, the optimal ICL algorithm first transforms the inputs by the representation function and then performs linear ICL on top of the transformed dataset. We show theoretically that there exist transformers of mild depth and size that approximately implement such algorithms. Empirically, we find that trained transformers consistently achieve near-optimal ICL performance in this setting. They also exhibit the desired dissection, where the lower layers transform the dataset and the upper layers perform linear ICL. Through extensive probing and a new pasting experiment, we further reveal several mechanisms within the trained transformers. These include concrete copying behaviors on both the inputs and the representations, the linear ICL capability of the upper layers alone, and a post-ICL representation selection mechanism in a harder mixture setting. These observed mechanisms align well with our theory and may shed light on how transformers perform ICL in more realistic scenarios.
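A similarly hedged sketch can illustrate post-ICL representation selection in the mixture setting. It assumes that each prompt's representation is drawn from a small, known set of candidates. It runs ridge ICL under every candidate first and only then picks one, scoring each candidate by its in-context mean squared residual. That scoring rule is a simplifying assumption chosen for illustration; a Bayesian treatment would instead weight candidates by their marginal likelihoods. It is not the mechanism measured inside the trained transformers. The helper names (`make_representation`, `ridge_fit`, `select_and_predict`) and the random ReLU candidates are hypothetical. The block includes its own small ridge solver so that it runs on its own.

```python
import random


def make_representation(d_in, d_rep, seed):
    """Hypothetical candidate representation: x -> relu(W x) with random W."""
    rng = random.Random(seed)
    W = [[rng.gauss(0.0, d_in ** -0.5) for _ in range(d_in)] for _ in range(d_rep)]
    return lambda x: [max(0.0, sum(a * b for a, b in zip(row, x))) for row in W]


def ridge_fit(feats, ys, lam):
    """Solve (F^T F + lam I) w = F^T y by Gauss-Jordan elimination with pivoting."""
    D = len(feats[0])
    M = [[sum(f[i] * f[j] for f in feats) + (lam if i == j else 0.0) for j in range(D)]
         + [sum(f[i] * y for f, y in zip(feats, ys))] for i in range(D)]
    for c in range(D):
        p = max(range(c, D), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(D):
            if r != c:
                factor = M[r][c] / M[c][c]
                M[r] = [a - factor * b for a, b in zip(M[r], M[c])]
    return [M[i][D] / M[i][i] for i in range(D)]


def select_and_predict(phis, xs, ys, x_query, lam):
    """Run ridge ICL under every candidate, then keep the candidate that fits the prompt best."""
    best = None
    for k, phi in enumerate(phis):
        feats = [phi(x) for x in xs]
        w_hat = ridge_fit(feats, ys, lam)
        # Assumed selection score: mean squared residual on the in-context examples.
        mse = sum((sum(a * b for a, b in zip(w_hat, f)) - y) ** 2
                  for f, y in zip(feats, ys)) / len(ys)
        if best is None or mse < best[0]:
            best = (mse, k, w_hat)
    _, k_best, w_best = best
    return k_best, sum(a * b for a, b in zip(w_best, phis[k_best](x_query)))


# Usage: three candidate representations; this prompt's labels come from candidate 2.
rng = random.Random(2)
d_in, d_rep, n = 4, 6, 60
phis = [make_representation(d_in, d_rep, seed=s) for s in range(3)]
true_k = 2
w_true = [rng.gauss(0.0, 1.0) for _ in range(d_rep)]
xs = [[rng.gauss(0.0, 1.0) for _ in range(d_in)] for _ in range(n)]
ys = [sum(a * b for a, b in zip(w_true, phis[true_k](x))) for x in xs]
x_query = [rng.gauss(0.0, 1.0) for _ in range(d_in)]
k_hat, pred = select_and_predict(phis, xs, ys, x_query, lam=1e-6)
target = sum(a * b for a, b in zip(w_true, phis[true_k](x_query)))
print(k_hat, pred, target)
```

All candidates here share the same feature dimension. Comparing in-sample residuals therefore does not favor one candidate simply because it has more parameters.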
