In-context Learning for Mixture of Linear Regressions: Existence, Generalization and Training Dynamics

Yanhao Jin, Krishnakumar Balasubramanian, Lifeng Lai

TL;DR

The paper tackles in-context learning (ICL) for mixtures of linear regressions (MoR) by proving the existence of transformer architectures that internally implement EM with gradient-based M-steps, and it derives high-probability error bounds that scale with prompt length and SNR. It further analyzes pretraining generalization and the gradient-flow dynamics of single linear self-attention layers, showing convergence to global optima under structured initializations. The theoretical results are complemented by extensive simulations in which transformers perform competitively with, and often better than, EM baselines across varying numbers of components, prompt lengths, and noise levels. These findings establish a quantitative framework for end-to-end ICL in MoR and suggest practical benefits for learning heterogeneous regression components via transformer-based prompts.
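The TL;DR states that the constructed transformer internally implements EM with gradient-based M-steps. To make that update concrete, here is a minimal NumPy sketch of gradient EM for the symmetric two-component case. It assumes the textbook model $y_i = z_i\langle\beta^*, x_i\rangle + \varepsilon_i$ with $z_i$ uniform on $\{-1,+1\}$, $x_i \sim N(0, I_d)$ and a known noise level $\sigma$. This setup is our assumption and not necessarily the paper's exact model, and the names `sample_symmetric_mor` and `gradient_em` are hypothetical.

```python
import numpy as np


def sample_symmetric_mor(n, beta_star, sigma, rng):
    """Draw n pairs from y = z * <beta_star, x> + noise, z uniform on {-1, +1} (assumed model)."""
    d = beta_star.shape[0]
    X = rng.standard_normal((n, d))
    z = rng.choice([-1.0, 1.0], size=n)
    y = z * (X @ beta_star) + sigma * rng.standard_normal(n)
    return X, y


def gradient_em(X, y, beta0, sigma, step=1.0, iters=50):
    """Gradient EM: an exact E-step followed by one gradient step on the M-step objective."""
    n = X.shape[0]
    beta = beta0.copy()
    for _ in range(iters):
        # E-step: posterior mean of z_i given (x_i, y_i) under the current beta,
        # E[z_i | x_i, y_i] = tanh(y_i <x_i, beta> / sigma^2).
        w = np.tanh(y * (X @ beta) / sigma**2)
        # M-step: gradient of Q(beta) = -(1/2n) sum_i (w_i y_i - <x_i, beta>)^2 + const.
        grad = X.T @ (w * y - X @ beta) / n
        beta = beta + step * grad
    return beta
```

Because the model is symmetric, $\beta^*$ is identifiable only up to sign, so the error should be measured as $\min(\|\hat\beta-\beta^*\|, \|\hat\beta+\beta^*\|)$. With $x_i \sim N(0, I_d)$, the empirical covariance is close to $I_d$, so `step=1.0` makes one gradient step nearly equivalent to the full M-step.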

Abstract

We investigate the in-context learning capabilities of transformers for the $d$-dimensional mixture of linear regression model, providing theoretical insights into their existence, generalization bounds, and training dynamics. Specifically, we prove that, in the high signal-to-noise ratio (SNR) regime, there exists a transformer that achieves a prediction error of order $\mathcal{O}(\sqrt{d/n})$ with high probability, where $n$ denotes the prompt length. Moreover, we derive in-context excess risk bounds of order $\mathcal{O}(L/\sqrt{B})$ for the case of two mixture components, where $B$ denotes the number of training prompts and $L$ the number of attention layers. The dependence of $L$ on the SNR is explicitly characterized and differs between the low and high SNR settings. We further analyze the training dynamics of transformers with a single linear self-attention layer, demonstrating that, with appropriately initialized parameters, gradient flow on the population mean squared loss converges to a global optimum. Extensive simulations suggest that transformers perform well on this task and can outperform baselines such as the Expectation-Maximization (EM) algorithm.
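The abstract analyzes transformers with a single linear self-attention layer. As a hedged illustration of what such a layer computes, the sketch below uses the parameterization common in the in-context linear regression literature: the prompt is embedded as $H = \begin{pmatrix} x_1 & \cdots & x_n & x_{n+1} \\ y_1 & \cdots & y_n & 0 \end{pmatrix}$, the layer outputs $H + \frac{1}{n} W^{PV} H M H^\top W^{KQ} H$, where $M$ masks out the query column, and the prediction for $y_{n+1}$ is read from the bottom-right entry. This embedding and parameterization are assumptions, not necessarily the paper's, and `lsa_predict`, `W_kq` and `W_pv` are hypothetical names.

```python
import numpy as np


def lsa_predict(X, y, x_query, W_kq, W_pv):
    """One linear (softmax-free) self-attention layer applied to the prompt embedding.

    X: (n, d) context inputs, y: (n,) context labels, x_query: (d,) query input.
    W_kq, W_pv: (d+1, d+1) merged key-query and projection-value matrices.
    Returns the bottom-right entry of the layer output, used as the prediction of y_{n+1}.
    """
    n, d = X.shape
    # Build H = [[x_1 ... x_n, x_query], [y_1 ... y_n, 0]] of shape (d+1, n+1).
    H = np.zeros((d + 1, n + 1))
    H[:d, :n] = X.T
    H[d, :n] = y
    H[:d, n] = x_query
    # Mask: only the n context columns act as keys/values; the query column does not.
    M = np.zeros((n + 1, n + 1))
    M[:n, :n] = np.eye(n)
    out = H + W_pv @ H @ M @ H.T @ W_kq @ H / n
    return out[d, n]
```

With $W^{PV}$ selecting only the label row (a single 1 in the bottom-right entry) and $W^{KQ} = \begin{pmatrix}\Gamma & 0\\ 0 & 0\end{pmatrix}$, the prediction reduces to $\frac{1}{n}\sum_{i=1}^n y_i\, x_i^\top \Gamma\, x_{n+1}$, a preconditioned one-step gradient estimate. Under this parameterization, training a single layer amounts to learning $\Gamma$.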

Paper Structure

This paper contains 27 sections, 26 theorems, 229 equations, 4 figures, and 1 algorithm.

Key Result

Theorem 3.1

Given the input matrix $H$ in the form of Eq. (Input_of_constructed_transformer), there exists a transformer $\mathrm{TF}$ with at most $M=4$ heads in each attention layer ($M^{(\ell)}\leq M=4$). This transformer $\mathrm{TF}$ can predict $y_{n+1}$ by implementing the gradient EM algorithm for the MoR pr under the SNR condition; equipped with $\mathcal{O}(T\log(n/d))$ attention layers, the transfor…

Figures (4)

  • Figure 1: Plot of excess testing risk of the transformer (and the EM algorithm) vs. prompt length for different SNRs on MoR tasks with $K=2,3,5,20$ components.
  • Figure 2: Plot of excess testing risk of the transformer vs. the number of prompts for different SNRs.
  • Figure 3: Plot of excess testing risk of the transformer vs. the dimension $d$ for different SNRs.
  • Figure 4: Plot of excess testing risk of the transformer vs. the hidden dimension $D$ for different SNRs.

Theorems & Definitions (45)

  • Definition 2.1
  • Remark 2.1
  • Definition 2.2: Attention only transformer
  • Theorem 3.1
  • Theorem 3.2
  • Theorem 3.3
  • Theorem 4.1: Generalization for pretraining
  • Remark 4.1
  • Theorem 4.2
  • Definition A.1: Approximability by sum of ReLUs
  • ...and 35 more