Test-Time Training Provably Improves Transformers as In-context Learners
Halil Alperen Gozeten, M. Emrullah Ildiz, Xuechen Zhang, Mahdi Soltanolkotabi, Marco Mondelli, Samet Oymak
TL;DR
This work develops a theoretical framework for test-time training (TTT) of one-layer linear transformers as in-context learners, deriving exact single-step update rules and optimal learning rates under both isotropic and general covariance. It shows that TTT can reduce the effective sample size needed for in-context learning and can mitigate distribution shift, with phase transitions determining whether warm-start or cold-start training is preferable. The authors corroborate the theory with experiments on TabPFN and GPT-2, demonstrating up to 5x reductions in the required in-context data and performance gains under distribution shift that depend on the alignment between the pretraining distribution and the target task. The findings offer a path toward more efficient, task-specific in-context learning with provable guarantees, along with practical guidance on when to deploy TTT in real systems.
Abstract
Test-time training (TTT) methods explicitly update the weights of a model to adapt to the specific test instance, and they have found success in a variety of settings, including, most recently, language modeling and reasoning. To demystify this success, we investigate a gradient-based TTT algorithm for in-context learning, where we train a transformer model on the in-context demonstrations provided in the test prompt. Specifically, we provide a comprehensive theoretical characterization of linear transformers when the update rule is a single gradient step. Our theory (i) delineates the role of alignment between the pretraining distribution and the target task, (ii) explains how TTT can alleviate distribution shift, and (iii) quantifies the sample complexity of TTT, including how it can significantly reduce the eventual sample size required for in-context learning. As our empirical contribution, we study the benefits of TTT for TabPFN, a tabular foundation model. In line with our theory, we demonstrate that TTT significantly reduces the required sample size for tabular classification (3 to 5 times fewer samples), unlocking substantial inference efficiency at negligible training cost.
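The single-gradient-step TTT procedure described above can be sketched in NumPy under several simplifying assumptions that are not taken from the paper: the one-layer linear transformer is reduced to the bilinear readout y_hat(x) = x^T W h, where h = (1/n) * sum_i y_i x_i is computed over the context demonstrations; the prompt's demonstrations are split into a context set and a held-out set that supplies the TTT loss; the identity matrix stands in for pretrained weights; and the step size is hand-picked. The function names `context_summary`, `predict`, `ttt_loss` and `ttt_single_step` are hypothetical.

```python
import numpy as np


def context_summary(X_ctx, y_ctx):
    # h = (1/n) * sum_i y_i x_i: the label-weighted average of context inputs.
    return X_ctx.T @ y_ctx / len(y_ctx)


def predict(W, X_ctx, y_ctx, X_query):
    # Simplified one-layer linear-attention readout (an assumption):
    # y_hat(x) = x^T W h for every row x of X_query.
    return X_query @ W @ context_summary(X_ctx, y_ctx)


def ttt_loss(W, X_ctx, y_ctx, X_tr, y_tr):
    # Half the mean squared error on the held-out demonstrations.
    resid = predict(W, X_ctx, y_ctx, X_tr) - y_tr
    return 0.5 * np.mean(resid ** 2)


def ttt_single_step(W, X_ctx, y_ctx, X_tr, y_tr, lr):
    # One gradient step on ttt_loss with respect to W.
    # The gradient of x^T W h w.r.t. W is the outer product x h^T, so
    # grad = (1/m) * sum_j resid_j * x_j h^T.
    h = context_summary(X_ctx, y_ctx)
    resid = X_tr @ W @ h - y_tr
    grad = np.outer(X_tr.T @ resid, h) / len(y_tr)
    return W - lr * grad


rng = np.random.default_rng(0)
d, n, m, n_test = 5, 20, 40, 200
beta = rng.normal(size=d)                      # hypothetical linear task
X = rng.normal(size=(n + m + n_test, d))
y = X @ beta + 0.1 * rng.normal(size=len(X))
X_ctx, y_ctx = X[:n], y[:n]                    # context demonstrations
X_tr, y_tr = X[n:n + m], y[n:n + m]            # demonstrations used for TTT
X_te, y_te = X[n + m:], y[n + m:]              # fresh queries for evaluation

W0 = np.eye(d)                                 # stand-in for pretrained weights
h = context_summary(X_ctx, y_ctx)
lr = 0.1 / (h @ h)                             # hypothetical, scale-aware step size
W1 = ttt_single_step(W0, X_ctx, y_ctx, X_tr, y_tr, lr)

for name, W in [("before TTT", W0), ("after TTT", W1)]:
    mse = np.mean((predict(W, X_ctx, y_ctx, X_te) - y_te) ** 2)
    print(f"{name}: query MSE = {mse:.4f}")
```

The update only touches W and costs a single pass over the held-out demonstrations. Comparing the two printed errors gives a rough feel for the adaptation on one random task; it is not evidence for the paper's sample-complexity or distribution-shift results.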
