In-Context Learning of Linear Systems: Generalization Theory and Applications to Operator Learning

Frank Cole, Yulong Lu, Wuzhe Xu, Tianhao Zhang

TL;DR

The paper provides a theoretical framework for in-context learning of linear systems using a linear transformer, deriving an in-domain generalization bound that decays with prompt length and pre-training size. It introduces a novel task-diversity concept to characterize when pre-trained transformers generalize under task distribution shifts, and proves that diversity is a sufficient (and in some cases necessary) condition for robust out-of-domain performance. The results extend to in-context operator learning for PDEs, offering an abstract operator-learning bound that translates into corollaries for elliptic problems. Numerical experiments on random matrices and linear elliptic PDEs validate the theoretical rates and demonstrate the practical impact of task diversity on out-of-domain generalization.

Abstract

We study theoretical guarantees for solving linear systems in-context using a linear transformer architecture. For in-domain generalization, we provide neural scaling laws that bound the generalization error in terms of the number of tasks and the sample sizes used in training and inference. For out-of-domain generalization, we find that the behavior of trained transformers under task distribution shifts depends crucially on the distribution of the tasks seen during training. We introduce a novel notion of task diversity and show that it defines a necessary and sufficient condition for pre-trained transformers to generalize under task distribution shifts. We also explore applications of learning linear systems in-context, such as in-context operator learning for PDEs. Finally, we provide numerical experiments to validate the established theory.

Paper Structure (32 sections, 28 theorems, 157 equations, 6 figures)

Key Result

Theorem 1

Adopt Assumption (assumption:taskanddatadistr). Let $\widehat{\theta} = \textrm{arg}\min_{\|\theta\| \leq M} \mathcal{R}_{n,N}(\theta)$, where the norm is defined by $\|\theta\| = \max\{\|P\|_{\textrm{op}},\|Q\|_{\textrm{op}}\}$ and $M > 0$ is any positive number such that $M \geq \max(1, \|\Sigma^{-1}\ldots)$ [...] If in addition Equation (eq:fstar) holds, then [...] where the implicit constants hidden in "$\lesssim$" d[...]

Figures (6)

  • Figure 1: In-domain generalization with $d=10$. The left plot shows the relative $L^2$ error as a function of the training prompt length $n$ using $N=20000$ training tasks. The middle plot shows the relative $L^2$ error as a function of the inference prompt length $m$ using $N=20000$ training tasks. The right plot shows the relative $L^2$ error as a function of the number of training tasks $N$ at a fixed training prompt length $n=10000$.
  • Figure 2: Out-of-domain generalization with $d=10$, $n=2000$ and $N=5000$. Training is performed with $D \sim \mathbf{U}_d[1, 2]$ and $y \sim \mathcal{N}(0, \Sigma_d(0)) = \mathcal{N}(0, \mathbf{I}_d)$. Testing involves different settings of $\mathbf{U}_d[a, b]$ and $\mathcal{N}(0, \Sigma_d(\rho))$. The left plot shows the relative $L^2$ error under shifts in $V(x)$: the solid blue curve represents the training distribution, while dashed lines indicate shifted test distributions. The right plot presents analogous results for shifts in $f(x)$, where the blue curve represents the training distribution and dashed lines indicate shifted test distributions.
  • Figure 3: Diversity test with dimension $d=10$, training prompt length $n=2000$, and number of tasks $N=5000$. Both plots show the relative $L^2$ error as a function of the testing prompt length $m$. The left plot corresponds to a transformer initialized near the minimizer $(P, Q) = (\mathbf{I}_d, \mathbf{I}_d)$, while the right plot uses the initialization $(P, Q) = (K, K^{-1})$, where $K$ is a diagonal matrix with entries sampled from the uniform distribution $U(1, 2)$.
  • Figure 4: In-domain generalization test for the elliptic PDE (eqn: ellipticpde) with $d=32$. The left plot shows the relative $H^1$ error as a function of the training prompt length $n$ using $N=20000$ training tasks. The middle plot shows the relative $H^1$ error as a function of the inference prompt length $m$ using $N=20000$ training tasks. The right plot shows the relative $H^1$ error as a function of the number of training tasks $N$ at a fixed training prompt length $n=10000$.
  • Figure 5: Out-of-domain generalization test with $d=32$, $n=2000$ and $N=20000$. Training is performed with $(\alpha_1, \beta_1)= (2, 2)$, $(\alpha_2, \beta_2)= (2, 2)$ and $(\alpha_3, \beta_3)= (2, 2)$. The columns show the relative $H^1$ error under distribution shifts in $a(x)$, $V(x)$ and $f(x)$, respectively, as a function of the inference prompt length $m$.
  • ...and 1 more figure
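
The quantity plotted in Figures 1 to 3, the relative $L^2$ error as a function of the inference prompt length $m$, can be illustrated with a small NumPy sketch. Several choices here are assumptions, not details confirmed by this summary: the task $D$ is taken to be a diagonal matrix with i.i.d. $U(1, 2)$ entries, each prompt holds pairs $(y_i, x_i)$ with $x_i = D^{-1} y_i$ and $y_i \sim \mathcal{N}(0, \mathbf{I}_d)$, and the linear transformer is reduced to the readout $P \left(\frac{1}{m}\sum_i x_i y_i^\top\right) Q\, y$. The function names are hypothetical, and no training is performed: the parameters are simply fixed at $(P, Q) = (\mathbf{I}_d, \mathbf{I}_d)$.

```python
import numpy as np

def relative_l2_error(pred, target):
    """Relative L2 error ||pred - target|| / ||target||."""
    return np.linalg.norm(pred - target) / np.linalg.norm(target)

def in_context_predict(P, Q, ys, xs, y_query):
    """Assumed linear-attention readout: P @ (mean_i x_i y_i^T) @ Q @ y_query."""
    m = ys.shape[0]
    cross_moment = xs.T @ ys / m  # (d, d) empirical average of x_i y_i^T
    return P @ cross_moment @ Q @ y_query

def mean_error(m, d=10, trials=200, seed=0):
    """Average relative L2 error over random tasks for prompt length m."""
    rng = np.random.default_rng(seed)
    P = Q = np.eye(d)  # fixed parameters (assumption: no training here)
    errors = []
    for _ in range(trials):
        diag = rng.uniform(1.0, 2.0, size=d)  # assumed diagonal task D
        ys = rng.standard_normal((m, d))      # prompt inputs y_i ~ N(0, I_d)
        xs = ys / diag                        # prompt solutions x_i = D^{-1} y_i
        y_query = rng.standard_normal(d)
        x_query = y_query / diag              # ground-truth solution for the query
        pred = in_context_predict(P, Q, ys, xs, y_query)
        errors.append(relative_l2_error(pred, x_query))
    return float(np.mean(errors))

if __name__ == "__main__":
    for m in (10, 100, 1000):
        print(m, mean_error(m))
```

Under these assumptions the empirical cross-moment converges to $D^{-1}$ as $m$ grows, so the printed error should shrink with longer prompts, which is the qualitative trend the captions describe.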

Theorems & Definitions (53)

  • Definition 1
  • Theorem 1
  • Definition 2
  • Theorem 2
  • Theorem 3
  • Corollary 1
  • Proposition 1
  • Theorem 4
  • Theorem 5
  • Corollary 2
  • ...and 43 more