In-context Learning for Mixture of Linear Regressions: Existence, Generalization and Training Dynamics
Yanhao Jin, Krishnakumar Balasubramanian, Lifeng Lai
TL;DR
The paper studies in-context learning (ICL) for mixtures of linear regressions (MoR). It proves that there exist transformer architectures that internally implement the Expectation-Maximization (EM) algorithm with gradient-based M-steps, and it derives high-probability error bounds that depend on the prompt length and the signal-to-noise ratio (SNR). It further analyzes pretraining generalization and the gradient-flow dynamics of single linear self-attention layers, showing convergence to a global optimum under suitably structured initializations. The theory is complemented by extensive simulations in which transformers are competitive with, and often outperform, EM baselines across varying numbers of components, prompt lengths, and noise levels. Together, these results give a quantitative account of end-to-end ICL for MoR and suggest that transformer-based prompting can be a practical way to learn heterogeneous regression components.
Abstract
We investigate the in-context learning capabilities of transformers for the $d$-dimensional mixture of linear regression model, providing theoretical results on existence, generalization, and training dynamics. Specifically, in the high signal-to-noise ratio (SNR) regime, we prove that there exists a transformer that achieves a prediction error of order $\mathcal{O}(\sqrt{d/n})$ with high probability, where $n$ denotes the training prompt size. Moreover, for the case of two mixture components, we derive in-context excess risk bounds of order $\mathcal{O}(L/\sqrt{B})$, where $B$ denotes the number of training prompts and $L$ the number of attention layers. The dependence of $L$ on the SNR is explicitly characterized and differs between the low and high SNR settings. We further analyze the training dynamics of transformers with a single linear self-attention layer, demonstrating that, with appropriately initialized parameters, gradient flow on the population mean squared loss converges to a global optimum. Extensive simulations suggest that transformers perform well on this task, potentially outperforming other baselines such as the Expectation-Maximization algorithm.
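To make the EM baseline concrete, the following is a minimal sketch, not the paper's implementation. It assumes a symmetric two-component model $y = s\, x^\top \beta^* + \varepsilon$ with $s$ uniform on $\{-1, +1\}$, standard Gaussian covariates, and a known noise level. The function names `sample_prompt` and `em_gradient_mstep`, the step size, and all numerical settings are hypothetical choices made for illustration. Each EM iteration replaces the exact M-step with a single gradient step on the expected squared loss.

```python
import numpy as np


def sample_prompt(n, d, beta_star, noise_std, rng):
    """Draw n (x, y) pairs from a symmetric two-component mixture of
    linear regressions (illustrative assumption: y = s * x^T beta* + noise)."""
    X = rng.standard_normal((n, d))
    signs = rng.choice([-1.0, 1.0], size=n)  # latent component of each example
    y = signs * (X @ beta_star) + noise_std * rng.standard_normal(n)
    return X, y


def em_gradient_mstep(X, y, noise_std, n_iters=500, step=0.5, rng=None):
    """EM for the symmetric mixture, taking one gradient step per iteration
    instead of solving the M-step exactly."""
    n, d = X.shape
    rng = np.random.default_rng() if rng is None else rng
    beta = rng.standard_normal(d) / np.sqrt(d)  # random initialization
    for _ in range(n_iters):
        # E-step: posterior mean of the latent sign given the current beta,
        # E[s | x, y] = tanh(y * x^T beta / sigma^2).
        w = np.tanh(y * (X @ beta) / noise_std**2)
        # M-step (one gradient step) on
        # Q(beta) = (1 / 2n) * sum_i E[(y_i - s_i x_i^T beta)^2].
        grad = X.T @ (X @ beta - w * y) / n
        beta = beta - step * grad
    return beta


# Usage: recover beta* up to its sign, which is not identifiable in this model.
rng = np.random.default_rng(0)
beta_star = np.ones(5)
X, y = sample_prompt(2000, 5, beta_star, 0.5, rng)
beta_hat = em_gradient_mstep(X, y, 0.5, rng=rng)
err = min(np.linalg.norm(beta_hat - beta_star),
          np.linalg.norm(beta_hat + beta_star))
print(f"estimation error up to sign: {err:.3f}")
```

The error is measured up to sign because the two components $\pm\beta^*$ are interchangeable under this assumed model. A comparison against a transformer would additionally require the pretraining setup described in the paper, which this sketch does not attempt.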
