Out-of-Distribution Generalization of In-Context Learning: A Low-Dimensional Subspace Perspective

Soo Min Kwon, Alec S. Xu, Can Yaras, Laura Balzano, Qing Qu

TL;DR

This work aims to demystify the out-of-distribution (OOD) capabilities of in-context learning (ICL) by studying linear regression tasks parameterized with low-rank covariance matrices, and proves an interesting property of ICL: when trained on task vectors drawn from a union of low-dimensional subspaces, ICL can generalize to any subspace within their span, given sufficiently long prompt lengths.

Abstract

This work aims to demystify the out-of-distribution (OOD) capabilities of in-context learning (ICL) by studying linear regression tasks parameterized with low-rank covariance matrices. With such a parameterization, we can model distribution shifts as a varying angle between the subspaces of the training and test covariance matrices. We prove that a single-layer linear attention model incurs a test risk with a non-negligible dependence on the angle, illustrating that ICL is not robust to such distribution shifts. However, using this framework, we also prove an interesting property of ICL: when trained on task vectors drawn from a union of low-dimensional subspaces, ICL can generalize to any subspace within their span, given sufficiently long prompt lengths. This suggests that the OOD generalization ability of Transformers may actually stem from the new task lying within the span of those encountered during training. We empirically show that our results also hold for models such as GPT-2, and conclude with (i) experiments on how our observations extend to nonlinear function classes and (ii) results showing that LoRA can capture distribution shifts.
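
A minimal sketch of the low-rank parameterization described above, in pure Python so it runs without dependencies. It assumes one simple way to place two rank-$r$ subspaces at angle $\theta$ (tilting each $\mathbf{e}_i$ toward $\mathbf{e}_{r+i}$); the paper's Equation (eqn:covariance_t) may construct $\mathbf{\Sigma}_t$ differently, and the helpers `subspace_pair`, `low_rank_covariance`, `cross_gram` and `trace_product` are hypothetical names chosen for illustration.

```python
import math

def subspace_pair(d, r, theta):
    """Two r-dimensional subspaces of R^d whose principal angles all equal theta.

    Hypothetical construction: U_s spans e_1..e_r, and U_t tilts each e_i
    toward e_{r+i} by theta. Each basis is a list of r orthonormal columns.
    """
    if d < 2 * r:
        raise ValueError("need d >= 2r so every tilted direction exists")
    U_s = [[1.0 if a == i else 0.0 for a in range(d)] for i in range(r)]
    U_t = [[math.cos(theta) if a == i else (math.sin(theta) if a == r + i else 0.0)
            for a in range(d)] for i in range(r)]
    return U_s, U_t

def low_rank_covariance(U):
    """Sigma = U U^T: positive semidefinite, rank r, with range span(U)."""
    d = len(U[0])
    return [[sum(u[a] * u[b] for u in U) for b in range(d)] for a in range(d)]

def cross_gram(U, V):
    """U^T V; its singular values are the cosines of the principal angles."""
    return [[sum(x * y for x, y in zip(u, v)) for v in V] for u in U]

def trace_product(A, B):
    """tr(A B) for square matrices stored as lists of rows."""
    n = len(A)
    return sum(A[i][k] * B[k][i] for i in range(n) for k in range(n))

theta = math.pi / 6
U_s, U_t = subspace_pair(d=6, r=2, theta=theta)
Sigma_s = low_rank_covariance(U_s)  # training task covariance
Sigma_t = low_rank_covariance(U_t)  # test task covariance, shifted by theta
G = cross_gram(U_s, U_t)            # cos(theta) * I_2 for this construction
print(G, trace_product(Sigma_s, Sigma_t), 2 * math.cos(theta) ** 2)
```

Because $\mathbf{U}_s^\top \mathbf{U}_t = \cos\theta\, \mathbf{I}_r$ in this construction, $\operatorname{tr}(\mathbf{\Sigma}_s \mathbf{\Sigma}_t) = r\cos^2\theta$: $\theta = 0$ recovers the training distribution, and $\theta = \pi/2$ makes the test subspace orthogonal to it.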

Paper Structure

This paper contains 63 sections, 13 theorems, 147 equations, 8 figures.

Key Result

Proposition 1

Let $g_{\texttt{ATT}}^\star$ denote the optimal linear attention model corresponding to the independent data setting in Equation (eq:vanilla_setup). For all $j \in [m+1]$, suppose that the prompts at test time are constructed with features $\mathbf{x}_j \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_d)$ and $\mathbf{\Sigma}_t \in \mathbb{R}^{d\times d}$ is from Equation (eqn:covariance_t). Then, we ha…
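
To illustrate the angle dependence in Proposition 1, the sketch below estimates test risk by Monte Carlo. It does not implement $g_{\texttt{ATT}}^\star$. As a simplifying assumption it uses a hypothetical stand-in predictor $\hat y = \mathbf{x}_q^\top \mathbf{P}_s \frac{1}{n}\sum_i y_i \mathbf{x}_i$, where $\mathbf{P}_s$ projects onto the training subspace, with noiseless labels, $\mathbf{x}_i \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_d)$ and $\mathbf{w} \sim \mathcal{N}(\mathbf{0}, \mathbf{\Sigma}_t)$. For this predictor the risk tends to $r\sin^2\theta$ as the prompt length $n \to \infty$, because the component of $\mathbf{w}$ outside the training subspace is never recovered. The subspace construction and all function names are illustrative.

```python
import math
import random

def subspace_pair(d, r, theta):
    """Hypothetical construction: U_s spans e_1..e_r; U_t tilts each e_i toward
    e_{r+i} by theta, so every principal angle between them equals theta."""
    U_s = [[1.0 if a == i else 0.0 for a in range(d)] for i in range(r)]
    U_t = [[math.cos(theta) if a == i else (math.sin(theta) if a == r + i else 0.0)
            for a in range(d)] for i in range(r)]
    return U_s, U_t

def squared_error_one_prompt(U_train, U_test, n, rng):
    """Squared error of the stand-in predictor on one random prompt."""
    d, r = len(U_test[0]), len(U_test)
    # Task vector w = U_test z with z ~ N(0, I_r), hence w ~ N(0, Sigma_t).
    z = [rng.gauss(0.0, 1.0) for _ in range(r)]
    w = [sum(z[k] * U_test[k][a] for k in range(r)) for a in range(d)]
    # n in-context examples x_i ~ N(0, I_d), y_i = <w, x_i> (no label noise).
    h = [0.0] * d  # h = (1/n) * sum_i y_i x_i, an unbiased estimate of w
    for _ in range(n):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        y = sum(wa * xa for wa, xa in zip(w, x))
        for a in range(d):
            h[a] += y * x[a] / n
    # Project h onto the training subspace: P_s h = sum_k <u_k, h> u_k.
    coeffs = [sum(ua * ha for ua, ha in zip(u, h)) for u in U_train]
    p_h = [sum(c * u[a] for c, u in zip(coeffs, U_train)) for a in range(d)]
    # Query x_q ~ N(0, I_d): compare the prediction with the true label.
    x_q = [rng.gauss(0.0, 1.0) for _ in range(d)]
    y_hat = sum(pa * xa for pa, xa in zip(p_h, x_q))
    y_true = sum(wa * xa for wa, xa in zip(w, x_q))
    return (y_hat - y_true) ** 2

def estimate_risk(U_train, U_test, n, trials=1000, seed=0):
    """Monte Carlo estimate of the expected squared error over prompts."""
    rng = random.Random(seed)
    return sum(squared_error_one_prompt(U_train, U_test, n, rng)
               for _ in range(trials)) / trials

def limiting_risk(theta, r):
    """n -> infinity limit for this predictor: E||(I - P_s) w||^2 = r sin^2(theta)."""
    return r * math.sin(theta) ** 2

U_s, _ = subspace_pair(6, 2, 0.0)
for theta in (0.0, math.pi / 6, math.pi / 3, math.pi / 2):
    _, U_t = subspace_pair(6, 2, theta)
    print(f"theta={theta:.3f}  risk(n=50)={estimate_risk(U_s, U_t, n=50):.3f}  "
          f"limit={limiting_risk(theta, 2):.3f}")
```

Increasing `n` moves the empirical risk toward `limiting_risk`, which stays bounded away from zero for $\theta > 0$; this is the qualitative behavior Proposition 1 describes, shown here for a simplified predictor only.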

Figures (8)

  • Figure 1: Overview of this paper. We consider two models: one trained with task vectors drawn from a single subspace, and one with task vectors drawn from a union of subspaces. At inference, we test both models using a task vector at an angle between two subspaces. The single-subspace model fails to generalize under distribution shifts, while the union-of-subspaces model generalizes across all angles.
  • Figure 2: Plot of the normalized test risk as a function of the prompt length for a linear Transformer (left) and a nonlinear Transformer (right) under covariance shifts. As the test-time covariance shifts away from the training covariance by an angle $\theta$, the test risk of both the linear and the nonlinear Transformer shows a non-negligible dependence on $\theta$. Moreover, for both models, the test risk exactly matches the risk predicted by Proposition (prop:neg_result).
  • Figure 3: Plot of the test risk as a function of the prompt length for a linear Transformer (left) and a nonlinear Transformer (right). When the test-time prompt is long enough, the test risk is nearly zero for all $\theta \in \left[0, \frac{\pi}{2} \right]$, corroborating Theorem (thm:mix_two_subspaces): both linear and nonlinear Transformers can generalize at test time to the span of the training task vectors.
  • Figure 4: Left: Phase plot of the test risk as the angle between $\mathbf{\Sigma}_s$ and $\mathbf{\Sigma}_t$ and the prompt length vary, with $m=n$, for a linear attention model trained on a mixture of Gaussians. The test risk is low across all angle shifts and decreases further as the prompt length increases. Right: Test risk as a function of the prompt length when $\mathbf{\Sigma}_s \neq \mathbf{\Sigma}_t$ but $\theta = 0$, following the OOD example of Gatmiry et al. (gatmirylooped2024). This helps explain why ICL can appear to generalize out of distribution, as reported in the literature.
  • Figure 5: Experiments demonstrating that LoRA can adapt to distribution shifts. Left: Test risk using the optimal LoRA adapters; the risk approaches zero as the prompt length increases. Right: Subspace error between the learned and analytical LoRA adapters over $5$ random runs, showing that gradient descent recovers the optimal adapters. The subspace error is defined in Equation (eqn:subspace_err).
  • ...and 3 more figures
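
Figure 5 compares learned and analytical LoRA adapters through a subspace error. Equation (eqn:subspace_err) is not reproduced on this page, so the sketch below uses a hypothetical stand-in: the mean squared sine of the principal angles between the column spaces of two adapter factors, $1 - \|\mathbf{Q}_1^\top \mathbf{Q}_2\|_F^2 / r$, where $\mathbf{Q}_1, \mathbf{Q}_2$ are orthonormal bases obtained by Gram-Schmidt. Matrices are stored as lists of columns, and the names `gram_schmidt` and `subspace_error` are illustrative.

```python
import math

def gram_schmidt(cols, tol=1e-12):
    """Orthonormal basis for the span of `cols` (a list of column vectors).
    Near-dependent columns (residual norm <= tol) are dropped."""
    basis = []
    for v in cols:
        w = list(v)
        for q in basis:
            c = sum(qi * wi for qi, wi in zip(q, w))
            w = [wi - c * qi for wi, qi in zip(w, q)]
        norm = math.sqrt(sum(wi * wi for wi in w))
        if norm > tol:
            basis.append([wi / norm for wi in w])
    return basis

def subspace_error(B_learned, B_ref):
    """Hypothetical metric: 1 - ||Q1^T Q2||_F^2 / r.
    Equals the mean of sin^2 over the principal angles when both spans have
    rank r; 0 means identical column spaces, 1 means orthogonal ones."""
    Q1 = gram_schmidt(B_learned)
    Q2 = gram_schmidt(B_ref)
    r = max(len(Q1), len(Q2))
    overlap = sum(sum(a * b for a, b in zip(q1, q2)) ** 2 for q1 in Q1 for q2 in Q2)
    return max(0.0, 1.0 - overlap / r)  # clamp tiny negative rounding error

# Reference adapter B (d = 4, rank 2), stored as columns: span{e1, e2}.
B_ref = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
# Same column space in a different basis (B_ref times an invertible G).
B_mixed = [[1.0, 1.0, 0.0, 0.0], [1.0, -2.0, 0.0, 0.0]]
# One direction tilted by 0.3 rad toward e3.
B_tilted = [[math.cos(0.3), 0.0, math.sin(0.3), 0.0], [0.0, 1.0, 0.0, 0.0]]
# Orthogonal column space: span{e3, e4}.
B_orth = [[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]

for name, B in [("mixed", B_mixed), ("tilted", B_tilted), ("orth", B_orth)]:
    print(name, f"{subspace_error(B, B_ref):.6f}")
```

Comparing column spaces rather than raw matrices is deliberate: a LoRA update $\mathbf{B}\mathbf{A}$ is unchanged under $\mathbf{B} \to \mathbf{B}\mathbf{G}$, $\mathbf{A} \to \mathbf{G}^{-1}\mathbf{A}$ for any invertible $\mathbf{G}$, so gradient descent may return the right adapter in a different basis, which a subspace metric scores as zero error.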

Theorems & Definitions (25)

  • Proposition 1: Task Distribution Shift
  • Theorem 1: Test Risk under the Span of Covariance Matrices
  • Theorem 2
  • Corollary 1
  • Proposition 2: Task Distribution Shift with Different Angles
  • Proposition 3: Feature Distribution Shift
  • Lemma 1: Test Risk under General Task Distribution Shift
  • Proof
  • Proof
  • Proof
  • ...and 15 more
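
Theorem 1 (Test Risk under the Span of Covariance Matrices) concerns the span, not the union, of the training subspaces. The toy computation below shows the geometric reason in the infinite-prompt-length limit under a simplifying assumption: it models a hypothetical predictor that projects onto the subspace covered in training (not the trained model analyzed in the theorem), whose limiting risk on a unit task vector $\mathbf{w}$ is the squared residual $\|\mathbf{w} - \mathbf{P}\mathbf{w}\|^2$. The two one-dimensional subspaces and the helpers `residual_sq`, `unit` and `task_vector` are illustrative.

```python
import math

def residual_sq(w, basis):
    """||w - P w||^2, where P projects onto span(basis); basis must be orthonormal."""
    coeffs = [sum(qi * wi for qi, wi in zip(q, w)) for q in basis]
    proj = [sum(c * q[i] for c, q in zip(coeffs, basis)) for i in range(len(w))]
    return sum((wi - p) ** 2 for wi, p in zip(w, proj))

def unit(d, i):
    """Standard basis vector e_i in R^d."""
    return [1.0 if a == i else 0.0 for a in range(d)]

d = 4
subspace_1 = [unit(d, 0)]             # hypothetical training subspace 1: span{e_1}
subspace_2 = [unit(d, 1)]             # hypothetical training subspace 2: span{e_2}
span_basis = subspace_1 + subspace_2  # orthonormal basis of their span

def task_vector(theta):
    """Unit task vector at angle theta from subspace_1, lying in the span."""
    return [math.cos(theta), math.sin(theta), 0.0, 0.0]

for theta in (0.0, math.pi / 8, math.pi / 4, 3 * math.pi / 8, math.pi / 2):
    w = task_vector(theta)
    single = residual_sq(w, subspace_1)                # trained on subspace 1 only
    nearest = min(single, residual_sq(w, subspace_2))  # best single member of the union
    spanned = residual_sq(w, span_basis)               # trained on both subspaces
    print(f"theta={theta:.3f}  single={single:.3f}  nearest={nearest:.3f}  span={spanned:.3f}")
```

For $\theta$ strictly between $0$ and $\pi/2$ the task vector lies in neither training subspace (`nearest` stays positive) yet has zero residual against their span. At finite prompt length a real model's risk is positive and shrinks as the prompt grows, as Figure 3 reports.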