Transformers are Minimax Optimal Nonparametric In-Context Learners

Juno Kim, Tai Nakamaki, Taiji Suzuki

TL;DR

It is shown that sufficiently trained transformers can achieve -- and even improve upon -- the minimax optimal estimation risk in context by encoding the most relevant basis representations during pretraining.

Abstract

In-context learning (ICL) of large language models has proven to be a surprisingly effective method of learning a new task from only a few demonstrative examples. In this paper, we study the efficacy of ICL from the viewpoint of statistical learning theory. We develop approximation and generalization error bounds for a transformer composed of a deep neural network and one linear attention layer, pretrained on nonparametric regression tasks sampled from general function spaces including the Besov space and piecewise $\gamma$-smooth class. We show that sufficiently trained transformers can achieve -- and even improve upon -- the minimax optimal estimation risk in context by encoding the most relevant basis representations during pretraining. Our analysis extends to high-dimensional or sequential data and distinguishes the pretraining and in-context generalization gaps. Furthermore, we establish information-theoretic lower bounds for meta-learners w.r.t. both the number of tasks and in-context examples. These findings shed light on the roles of task diversity and representation learning for ICL.

Paper Structure (56 sections, 32 theorems, 186 equations, 3 figures)

Key Result

Theorem 3.1

There exists a universal constant $C$ such that for any $\epsilon>0$ satisfying $\mathcal{V}(\mathcal{T}_N,\lVert\cdot\rVert_{L^\infty},\epsilon)\geq 1$,

Figures (3)

  • Figure 1: Architecture of the compared models. Each model contains two MLP components; all attention layers are single-head, and LayerNorm is not included. (a),(b) implement the simplified reparametrization for attention, while all layers in (c) utilize the full embeddings. The input dimension is 8 and all hidden layer and DNN output widths are 32. The query prediction is read off the last entry of the output at the query position.
  • Figure 2: Training and test curves for the ICL pretraining objective. We use the Adam optimizer with a learning rate of 0.02 for all layers. For the task class we take $\alpha=1$, $p=q=\infty$, $T=n=512$ and generate samples from random combinations of order 2 wavelets.
  • Figure 3: Training and test losses of the three models after 50 epochs while varying (a) DNN width $N$; (b) number of in-context samples $n$; (c) number of tasks $T$. For (a), the widths of all hidden layers also vary with $N$. We take the median over 5 runs for robustness.
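
The architecture named in the abstract and in Figure 1, a deep neural network feeding one linear attention layer, can be illustrated with a minimal NumPy sketch. The readout form used here, $\phi(\tilde{x})^\top \Gamma \, \frac{1}{n}\sum_{k} y_k \phi(x_k)$, is an assumption about what the "simplified reparametrization" looks like and is not stated on this page. The helper names (`init_mlp`, `mlp`, `icl_predict`), the two hidden layers, the random untrained weights, and the identity choice for `Gamma` are all hypothetical and chosen only for illustration. Only the input dimension 8, the widths 32, and $n=512$ come from the figure captions.

```python
import numpy as np

def init_mlp(rng, sizes):
    """Random (untrained) weights for a ReLU MLP; depth is an assumption."""
    return [(rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, k)), np.zeros(k))
            for m, k in zip(sizes[:-1], sizes[1:])]

def mlp(params, x):
    """Apply the MLP row-wise, with ReLU on every layer except the last."""
    h = x
    for i, (W, b) in enumerate(params):
        h = h @ W + b
        if i < len(params) - 1:
            h = np.maximum(h, 0.0)
    return h

def icl_predict(params, Gamma, X, y, x_query):
    """Assumed linear-attention readout: phi(x_q)^T Gamma (1/n) sum_k y_k phi(x_k)."""
    phi_X = mlp(params, X)                      # features of context inputs, shape (n, N)
    phi_q = mlp(params, x_query[None, :])[0]    # features of the query, shape (N,)
    summary = phi_X.T @ y / len(y)              # label-weighted feature average, shape (N,)
    return float(phi_q @ Gamma @ summary)

rng = np.random.default_rng(0)
d, N, n = 8, 32, 512                            # input dim, DNN output width, context length
params = init_mlp(rng, [d, 32, 32, N])          # two hidden layers: an assumption
Gamma = np.eye(N)                               # placeholder attention matrix (hypothetical)
X = rng.uniform(size=(n, d))                    # synthetic context inputs
y = rng.normal(size=n)                          # synthetic context labels
x_query = rng.uniform(size=d)
pred = icl_predict(params, Gamma, X, y, x_query)
```

Under this assumed form the prediction is linear in the context labels and invariant to the order of the context examples, which is what makes a single linear attention layer act like a regression estimator on the learned features.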

Theorems & Definitions (42)

  • Remark 2.1
  • Theorem 3.1: Schmidt20, Lemma 4, adapted
  • proof
  • Proposition 3.2
  • Lemma 3.3
  • Definition 4.1: Besov space
  • Proposition 4.2: Donoho98
  • Definition 4.3: B-spline wavelet basis
  • Lemma 4.4
  • Theorem 4.5: minimax optimality of ICL in Besov space
  • ...and 32 more