Learning Linear Regression with Low-Rank Tasks in-Context

Kaito Takanami, Takashi Takahashi, Yoshiyuki Kabashima

TL;DR

The paper studies in-context learning (ICL) theoretically by analyzing a linear-attention transformer trained on low-rank regression tasks in the high-dimensional limit. It decomposes ICL predictions into an algorithmic signal and two suppressible noise components (memorization and structural), and shows that finite pre-training data induce an implicit regularization that stabilizes learning in low-rank settings. A phase transition emerges, governed by the relationship between task dimensionality and task diversity (at $\rho = \kappa$ in Figure 3), with implications for pre-training strategy and curriculum design. The results provide a framework for interpreting how transformers learn task structure and adapt to out-of-distribution scenarios, with concrete predictions for generalization under the TM, IDG, and ODG protocols.

Abstract

In-context learning (ICL) is a key building block of modern large language models, yet its theoretical mechanisms remain poorly understood. It is particularly mysterious how ICL operates in real-world applications where tasks have a common structure. In this work, we address this problem by analyzing a linear attention model trained on low-rank regression tasks. Within this setting, we precisely characterize the distribution of predictions and the generalization error in the high-dimensional limit. Moreover, we find that statistical fluctuations in finite pre-training data induce an implicit regularization. Finally, we identify a sharp phase transition of the generalization error governed by task structure. These results provide a framework for understanding how transformers learn to learn the task structure.

Paper Structure

This paper contains 48 sections, 7 theorems, 115 equations, 5 figures.

Key Result

Theorem 5.1

Recall that $\rho=r/D$ and $\kappa=M_0/D$. The empirical spectral distribution of $S$ converges almost surely to a limiting distribution whose expression involves $\mathbf{1}_{[s_-,s_+]}(s)$, the indicator function of the interval $[s_-,s_+]$. The support edges are given by $s_\pm = \left(\sqrt{\rho} \pm \sqrt{\kappa}\right)^2 / (\rho\kappa)$.
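As a quick sanity check on the edge formula above, the following minimal Python sketch evaluates $s_\pm$ for given $\rho$ and $\kappa$. The function names `support_edges` and `support_edges_alt` are illustrative and do not come from the paper; the sample values $\rho=0.9$, $\kappa=1.5$ are the ones listed in the Figure 1 caption. The second function uses the algebraically equivalent form $s_\pm = (1/\sqrt{\kappa} \pm 1/\sqrt{\rho})^2$, obtained by dividing numerator and denominator by $\rho\kappa$ inside the square.

```python
import math
from typing import Tuple


def support_edges(rho: float, kappa: float) -> Tuple[float, float]:
    """Return (s_minus, s_plus) with s_pm = (sqrt(rho) +/- sqrt(kappa))^2 / (rho * kappa)."""
    if rho <= 0 or kappa <= 0:
        raise ValueError("rho and kappa must be positive")
    root_rho, root_kappa = math.sqrt(rho), math.sqrt(kappa)
    s_minus = (root_rho - root_kappa) ** 2 / (rho * kappa)
    s_plus = (root_rho + root_kappa) ** 2 / (rho * kappa)
    return s_minus, s_plus


def support_edges_alt(rho: float, kappa: float) -> Tuple[float, float]:
    """Same edges via the equivalent form (1/sqrt(kappa) -/+ 1/sqrt(rho))^2."""
    inv_root_rho, inv_root_kappa = 1.0 / math.sqrt(rho), 1.0 / math.sqrt(kappa)
    return (inv_root_kappa - inv_root_rho) ** 2, (inv_root_kappa + inv_root_rho) ** 2


# Illustrative values taken from the Figure 1 caption (rho = 0.9, kappa = 1.5).
s_minus, s_plus = support_edges(rho=0.9, kappa=1.5)
print(f"s_- = {s_minus:.4f}, s_+ = {s_plus:.4f}")

# When rho equals kappa the two square roots cancel and the lower edge is zero.
print(support_edges(rho=0.5, kappa=0.5))
```

Setting $\rho=\kappa$ makes the lower edge $s_-$ vanish, which is the same point at which the Figure 3 caption places the phase transition. The sketch only checks the formula; it does not construct $S$ or reproduce the limiting distribution.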

Figures (5)

  • Figure 1: Decomposition of generalization error reveals ICL's noise-reduction mechanism. The total generalization error (theory: dashed lines; numerical experiments: error bars) and its decomposition into algorithmic (signal), memorization (noise), and structural (noise) components (solid lines) for each protocol. Parameters: $\gamma=1.5, \kappa=1.5, \rho=0.9, \sigma=0.3, \tilde{\alpha}=\alpha, D = 60$. The error bars represent the standard errors of the mean over $5$ independent numerical experiments per point.
  • Figure 2: Implicit regularization from finite data prevents learning instability. IDG error as a function of the task subspace dimensionality $\rho$. The figure compares the idealized limit (dashed line) with the standard setting at finite sample ratio $\alpha$ (solid lines). Parameters: $D = 40, M_0 = 60, \tilde{L} = 160, M = 2400, \sigma = 0.01$. The error bars represent the standard errors of the mean over $5$ trials per point.
  • Figure 3: Phase transition in the model's capabilities. Generalization errors ($\mathcal{E}_{\text{TM}}, \mathcal{E}_{\text{IDG}}, \mathcal{E}_{\text{ODG}}$) as a function of the task difficulty $\rho$. The results confirm the predicted phase transition in generalization errors at $\rho = \kappa$. The error bars represent the standard errors of the mean over $5$ trials per point. Parameters: $\alpha=200 \text{ (solid lines)}, \alpha\to\infty \text{ (dashed lines)}, \tilde{\alpha}=4.0, \gamma=8.0, \kappa = 0.5, \sigma=0, D = 40$.
  • Figure 5: Comparison of the order parameters $Q, m, \bar{Q}, \bar{m}, Q_0, m_0$ obtained from the replica method (lines) with those obtained from the numerical experiments (error bars). Parameters for (A-D): $\lambda=0.1, \sigma=0.1, D=70$; (A) $\gamma=1.5, \rho=0.4, \alpha=2.0$; (B) $\kappa=1.0, \rho=0.4, \alpha=2.0$; (C) $\kappa=1.0, \gamma=1.5, \alpha=2.0$; (D) $\kappa=1.0, \gamma=1.5, \rho=0.4$. Error bars represent the standard errors of the mean over $10$ trials per point.
  • Figure 6: Comparison of the generalization errors ($\mathcal{E}_{\text{TM}}, \mathcal{E}_{\text{IDG}}, \mathcal{E}_{\text{ODG}}$) obtained from the replica method (lines) with those obtained from the numerical experiments (error bars). Parameters for (A-D): $\kappa=4.0, \gamma = 0.5, \alpha=4.0, \lambda=0.1, \sigma=0.1, D=60$; (A) $\rho=0.9$; (B) $\rho=0.2$. Error bars represent the standard errors of the mean over $5$ trials per point.

Theorems & Definitions (12)

  • Theorem 5.1: Spectrum of the Task Matrix
  • Lemma C.2
  • proof
  • Lemma D.1
  • proof
  • Lemma D.2
  • proof
  • Proposition D.3
  • proof
  • Lemma D.4
  • ...and 2 more