Pretrained transformer efficiently learns low-dimensional target functions in-context

Kazusato Oko, Yujin Song, Taiji Suzuki, Denny Wu

TL;DR

The pretrained transformer adapts to the low-dimensional structure of the function class, which enables sample-efficient ICL that outperforms estimators with access only to the in-context data.

Abstract

Transformers can efficiently learn in-context from example demonstrations. Most existing theoretical analyses have studied the in-context learning (ICL) ability of transformers for linear function classes, where it is typically shown that the minimizer of the pretraining loss implements one gradient descent step on the least squares objective. However, this simplified linear setting arguably does not demonstrate the statistical efficiency of ICL, since the pretrained transformer does not outperform directly solving linear regression on the test prompt. In this paper, we study ICL of a nonlinear function class via a transformer with a nonlinear MLP layer: given a class of single-index target functions $f_*(\boldsymbol{x}) = \sigma_*(\langle\boldsymbol{x},\boldsymbol{\beta}\rangle)$, where the index features $\boldsymbol{\beta}\in\mathbb{R}^d$ are drawn from an $r$-dimensional subspace, we show that a nonlinear transformer optimized by gradient descent (with a pretraining sample complexity that depends on the information exponent of the link functions $\sigma_*$) learns $f_*$ in-context with a prompt length that depends only on the dimension $r$ of the distribution of target functions; in contrast, any algorithm that directly learns $f_*$ on the test prompt incurs a statistical complexity that scales with the ambient dimension $d$. Our result highlights the adaptivity of the pretrained transformer to low-dimensional structures of the function class, which enables sample-efficient ICL that outperforms estimators that only have access to the in-context data.
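
The single-index task class described in the abstract can be made concrete with a short data-generating sketch. It assumes standard Gaussian inputs, noiseless labels, one random $r$-dimensional subspace shared by all prompts, and $\boldsymbol{\beta}$ drawn uniformly from the unit sphere of that subspace. These are illustrative choices, not necessarily the paper's exact pretraining distribution, and the helper names `random_subspace` and `sample_prompt` are hypothetical.

```python
import numpy as np

def random_subspace(d, r, rng):
    """Orthonormal basis U (d x r) of a random r-dimensional subspace of R^d."""
    U, _ = np.linalg.qr(rng.standard_normal((d, r)))  # reduced QR: U has shape (d, r)
    return U

def sample_prompt(U, N, link, rng):
    """One in-context prompt (x_1, y_1, ..., x_N, y_N, x) for f_*(x) = link(<x, beta>).

    Assumed (illustrative) setup: x ~ N(0, I_d), noiseless labels, and beta a
    uniformly random unit vector inside span(U).
    """
    d, r = U.shape
    v = rng.standard_normal(r)
    beta = U @ (v / np.linalg.norm(v))  # unit-norm index feature in the r-dim subspace
    X = rng.standard_normal((N + 1, d))  # N demonstrations plus one query point
    y = link(X @ beta)                   # single-index labels sigma_*(<x, beta>)
    return X[:N], y[:N], X[N], y[N], beta

rng = np.random.default_rng(0)
U = random_subspace(d=32, r=8, rng=rng)  # shared by every prompt (pretraining and test)
he3 = lambda z: z**3 - 3 * z             # illustrative link: Hermite polynomial He_3
X, y, x_q, y_q, beta = sample_prompt(U, N=64, link=he3, rng=rng)
print(X.shape, y.shape, x_q.shape)  # (64, 32) (64,) (32,)
```

Because `U` is shared, every task vector lies in the same $r$-dimensional subspace; this is the low-dimensional structure that pretraining can exploit, whereas an estimator fitted on a single prompt sees only $N$ samples in $\mathbb{R}^d$.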


Paper Structure

This paper contains 52 sections, 42 theorems, 121 equations, 2 figures, and 1 algorithm.

Key Result

Theorem 1

Let $f:(\boldsymbol{x}_1,y_1,\dotsc,\boldsymbol{x}_N,y_N,\boldsymbol{x})\mapsto y$ be a transformer with a nonlinear MLP layer, pretrained with gradient descent (Algorithm 1) on the single-index regression task (eq:single-index). With probability at least 0.99, the model $f$ achieves in-context … where $Q, P$ are the information exponent and the highest degree of the link functions, respectively.
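
Theorem 1 is stated in terms of $Q$ and $P$. Assuming the standard definition, namely that the information exponent of $\sigma_*$ is the smallest $k \ge 1$ whose coefficient on the probabilists' Hermite polynomial $\mathrm{He}_k$ is nonzero, both quantities can be read off the Hermite expansion of a polynomial link. The sketch below uses NumPy's `numpy.polynomial.hermite_e.poly2herme`; the helper name `hermite_profile` is hypothetical.

```python
import numpy as np
from numpy.polynomial import hermite_e as He

def hermite_profile(power_coeffs, tol=1e-12):
    """Return (Q, P) for a polynomial link sigma(z) = sum_k a_k z^k.

    power_coeffs: [a_0, a_1, ..., a_deg] in increasing degree.
    Q: information exponent, the smallest k >= 1 with a nonzero He_k coefficient
       (standard definition, assumed here).
    P: the highest k with a nonzero He_k coefficient (the polynomial degree).
    """
    c = He.poly2herme(np.asarray(power_coeffs, dtype=float))  # Hermite (He_k) coefficients
    nonzero = [k for k in range(1, len(c)) if abs(c[k]) > tol]
    if not nonzero:
        raise ValueError("constant link: information exponent is undefined")
    return nonzero[0], nonzero[-1]

print(hermite_profile([0, 0, 0, 1]))      # (1, 3): z^3 = He_3 + 3 He_1
print(hermite_profile([0, -3, 0, 1]))     # (3, 3): z^3 - 3z = He_3
print(hermite_profile([2, 0, -5, 0, 1]))  # (2, 4): z^4 - 5z^2 + 2 = He_4 + He_2
```

The second and third links show why $Q$ and $P$ are separate parameters: a link can have low degree but a high information exponent, or the reverse.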

Figures (2)

  • Figure 1: In-context generalization error of kernel ridge regression, a neural network trained with gradient descent, and the pretrained transformer. The target function is a polynomial single-index model. We fix $r=8$ and vary $d \in \{16, 32\}$.
  • Figure 2: In-context sample complexity of a GPT-2 model pretrained on Gaussian single-index functions with a degree-4 polynomial link (see the experimental-setting section for details). Observe that (a) the ICL risk curves overlap across ambient dimensions $d$ when the target (subspace) dimensionality $r$ is the same, and (b) the required sample size $N^*$ grows as $r$ increases.
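
The kernel ridge regression baseline in Figure 1 is an estimator that sees only the in-context pairs of the test prompt. A minimal sketch of such an estimator follows. It assumes a Gaussian (RBF) kernel and a ridge penalty chosen purely for illustration; the kernel and hyperparameters behind Figure 1 are not given in this summary, and the names `rbf_kernel` and `krr_predict` are hypothetical.

```python
import numpy as np

def rbf_kernel(A, B, bandwidth):
    """Gaussian kernel matrix k(a, b) = exp(-||a - b||^2 / (2 h^2))."""
    sq = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2 * A @ B.T
    return np.exp(-np.maximum(sq, 0.0) / (2 * bandwidth**2))  # clip tiny negative round-off

def krr_predict(X, y, X_query, lam=1e-3, bandwidth=None):
    """Kernel ridge regression fitted only on the in-context pairs (X, y).

    Assumed (illustrative) choices: RBF kernel and ridge penalty N * lam.
    """
    N, d = X.shape
    if bandwidth is None:
        # For x, x' ~ N(0, I_d), E||x - x'||^2 = 2d, so h = sqrt(d) keeps the exponent near -1.
        bandwidth = np.sqrt(d)
    K = rbf_kernel(X, X, bandwidth)
    alpha = np.linalg.solve(K + N * lam * np.eye(N), y)  # solve (K + N*lam*I) alpha = y
    return rbf_kernel(X_query, X, bandwidth) @ alpha

rng = np.random.default_rng(1)
X = rng.standard_normal((50, 3))
y = X[:, 0]  # toy target, for illustration only
X_query = rng.standard_normal((5, 3))
y_hat = krr_predict(X, y, X_query)
print(y_hat.shape)  # (5,)
```

Such an estimator treats all $d$ input directions alike, which is consistent with the abstract's point that methods using only the test prompt have a sample complexity scaling with $d$ rather than $r$.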

Theorems & Definitions (48)

  • Theorem: Informal
  • Remark 1
  • Remark 2
  • Theorem 1
  • Remark 3
  • Proposition 2: Informal
  • Definition 3
  • Lemma 4
  • Lemma 5
  • Lemma 6
  • ...and 38 more