
How Transformers Learn In-Context Recall Tasks? Optimality, Training Dynamics and Generalization

Quan Nguyen, Thanh Nguyen-Tang

TL;DR

This work analyzes how transformers learn in-context recall tasks using a one-layer decoder-only model with fixed embeddings and trainable unembedding, value, and joint query-key matrices. It proves that a class of such models with linear, ReLU, or softmax attention is Bayes-optimal in both noiseless and noisy settings, shows that normalized gradient descent converges linearly to the Bayes risk, and gives finite-sample PAC-style generalization guarantees. A key finding is that proper parameterization is crucial for achieving both Bayes-optimal in-distribution performance and out-of-distribution (OOD) generalization; gradient descent alone may fail to generalize without such structure. Empirical validations across noiseless and noisy regimes support the results and illustrate how parameterization affects the in-context reasoning and OOD robustness of one-layer transformers.
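As a rough illustration of the architecture described above, the sketch below computes next-token logits at the last position of a one-layer, single-head decoder whose input embeddings are fixed and whose unembedding U, value V and joint query-key W are the trainable parameters. It is a minimal sketch under assumptions not taken from the paper: the attention score is $x_T^\top W x_t$, linear and ReLU weights are divided by the sequence length $T$, and the function names (`forward`, `attention_weights`) are hypothetical. Plain Python lists are used so no third-party packages are needed.

```python
import math
import random


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def matvec(M, x):
    """Matrix (list of rows) times vector."""
    return [dot(row, x) for row in M]


def attention_weights(scores, kind):
    """Map raw scores to weights for three attention variants.

    Assumption: linear and ReLU weights are averaged over the T positions.
    """
    T = len(scores)
    if kind == "linear":
        return [s / T for s in scores]
    if kind == "relu":
        return [max(s, 0.0) / T for s in scores]
    if kind == "softmax":
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        return [e / z for e in exps]
    raise ValueError(f"unknown attention kind: {kind}")


def forward(X, U, V, W, kind="softmax"):
    """Hypothetical one-layer readout at the last (query) position.

    X: T fixed input embeddings of length d (not trained).
    U (N x d), V (d x d), W (d x d): trainable unembedding, value, query-key.
    Returns N next-token logits.
    """
    x_query = X[-1]
    scores = [dot(x_query, matvec(W, x_t)) for x_t in X]  # x_T^T W x_t
    weights = attention_weights(scores, kind)
    h = [0.0] * len(x_query)
    for a_t, x_t in zip(weights, X):
        v_t = matvec(V, x_t)
        h = [h_i + a_t * v_i for h_i, v_i in zip(h, v_t)]
    return matvec(U, h)


def random_matrix(rng, rows, cols):
    """Gaussian random matrix as a list of rows."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d, N, T = 4, 3, 5
X = random_matrix(rng, T, d)  # fixed embeddings of a length-T prompt
U = random_matrix(rng, N, d)  # unembedding
V = random_matrix(rng, d, d)  # value
W = random_matrix(rng, d, d)  # joint query-key
for kind in ("linear", "relu", "softmax"):
    print(kind, [round(z, 3) for z in forward(X, U, V, W, kind)])
```

Only U, V and W would receive gradient updates in this setup; the embeddings in `X` stay fixed, mirroring the split described in the summary.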

Abstract

We study the approximation capabilities, convergence speeds and on-convergence behaviors of transformers trained on in-context recall tasks, which require recognizing the \emph{positional} association between a pair of tokens from in-context examples. Existing theoretical results focus only on the in-context reasoning behavior of transformers after \emph{one} gradient descent step. It remains unclear what the on-convergence behavior of transformers trained by gradient descent is and how fast they converge. In addition, the generalization of transformers in one-step in-context reasoning has not been formally investigated. This work addresses these gaps. We first show that a class of transformers with linear, ReLU or softmax attention is provably Bayes-optimal for an in-context recall task. For training with gradient descent, we show via a finite-sample analysis that the expected loss converges at a linear rate to the Bayes risk. Moreover, we show that the trained transformers exhibit out-of-distribution (OOD) generalization, i.e., they generalize to samples outside the population distribution. Our theoretical findings are further supported by extensive empirical validations, showing that \emph{without} proper parameterization, models with larger expressive power surprisingly \emph{fail} to generalize OOD after being trained by gradient descent.
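To make the task concrete, here is a toy, noiseless version of an in-context recall sequence together with the lookup rule that solves it. This is a hypothetical simplification, not the paper's Definition 2.1 (which is not reproduced here): one trigger $q \in \mathcal{Q}$ is planted once in the context, immediately followed by an output token $o$, and the sequence ends with $q$, so the label is $o$. The helper names `sample_recall_example` and `recall_predict` are illustrative only.

```python
import random


def sample_recall_example(vocab_size, triggers, seq_len, rng):
    """Draw one (tokens, label) pair for a toy, noiseless recall task.

    Hypothetical data model: a trigger q and an output o are drawn per
    sequence, the pair (q, o) is planted once at a random position, every
    other context token is a non-trigger, and the sequence ends with q.
    """
    q = rng.choice(triggers)
    non_triggers = [k for k in range(vocab_size) if k not in triggers]
    o = rng.choice(non_triggers)
    # seq_len - 1 context tokens followed by the final query token q
    context = [rng.choice(non_triggers) for _ in range(seq_len - 1)]
    pos = rng.randrange(seq_len - 2)  # keeps pos + 1 inside the context
    context[pos], context[pos + 1] = q, o
    return context + [q], o


def recall_predict(tokens):
    """Find the last occurrence of the query tokens[-1] inside the context
    and return the context token right after it (None if absent)."""
    query, context = tokens[-1], tokens[:-1]
    for i in range(len(context) - 2, -1, -1):
        if context[i] == query:
            return context[i + 1]
    return None


rng = random.Random(0)
trials = 1000
correct = 0
for _ in range(trials):
    tokens, label = sample_recall_example(20, [0, 1, 2], 16, rng)
    correct += recall_predict(tokens) == label
print(correct / trials)  # 1.0
```

Because the trigger occurs exactly once in the context of this toy model, the positional lookup is always correct, which is why a recall rule can reach zero error here.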

Paper Structure

This paper contains 36 sections, 16 theorems, 84 equations, 5 figures, 1 table, 1 algorithm.

Key Result

Lemma 3.1

Let $\bm{\lambda} = \{\lambda_k \in \mathbb{R}_+: k \in \mathcal{Q}\}$ be a set of $\abs{\mathcal{Q}}$ non-negative values. By setting $\bm{U} = [E(1)~E(2)~\dots~E(N)]^\top$, $\bm{V} = \bm{I}_d$ and $\bm{W} = \sum_{k \in \mathcal{Q}} \lambda_k E(k)\tilde{E}^\top(k)$, for both linear a
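A small numerical check of the construction in Lemma 3.1 can be sketched as follows. It assumes, purely for illustration, that the embeddings are one-hot (so the $E(k)$ and $\tilde{E}(k)$ are orthonormal and mutually orthogonal), that $\tilde{E}(k)$ encodes the previous token so the input at position $t$ is $E(z_t) + \tilde{E}(z_{t-1})$, and that linear and ReLU attention average the scores over the $T$ positions. Under these assumptions only the position right after the earlier occurrence of the query trigger gets a nonzero score, so the logits peak at the recalled token. The helper names are hypothetical, and the paper's exact embeddings and normalization may differ.

```python
def one_hot(i, dim):
    v = [0.0] * dim
    v[i] = 1.0
    return v


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(M, x):
    return [dot(row, x) for row in M]


N = 6      # vocabulary size (hypothetical)
d = 2 * N  # assumption: E(k) and E~(k) occupy orthogonal halves of R^d


def E(k):
    return one_hot(k, d)


def E_tilde(k):
    return one_hot(N + k, d)


def embed(tokens):
    """Assumed input: x_t = E(z_t) + E~(z_{t-1}), with x_0 = E(z_0)."""
    X = []
    for t, z in enumerate(tokens):
        x = E(z)
        if t > 0:
            x = [a + b for a, b in zip(x, E_tilde(tokens[t - 1]))]
        X.append(x)
    return X


def lemma_params(lambdas):
    """U = [E(1) ... E(N)]^T, V = I_d, W = sum_k lambda_k E(k) E~(k)^T."""
    U = [E(k) for k in range(N)]
    V = [one_hot(i, d) for i in range(d)]
    W = [[0.0] * d for _ in range(d)]
    for k, lam in lambdas.items():
        e, e_t = E(k), E_tilde(k)
        for i in range(d):
            for j in range(d):
                W[i][j] += lam * e[i] * e_t[j]
    return U, V, W


def predict(tokens, U, V, W, relu=False):
    """Linear (or ReLU) attention averaged over T positions (assumption)."""
    X = embed(tokens)
    T = len(X)
    scores = [dot(X[-1], matvec(W, x)) for x in X]  # x_T^T W x_t
    if relu:
        scores = [max(s, 0.0) for s in scores]
    h = [0.0] * d
    for s, x in zip(scores, X):
        h = [h_i + (s / T) * v_i for h_i, v_i in zip(h, matvec(V, x))]
    logits = matvec(U, h)
    return max(range(N), key=lambda j: logits[j]), logits


U, V, W = lemma_params({0: 1.0, 1: 1.0})  # triggers Q = {0, 1}, lambda_k = 1
tokens = [3, 0, 5, 4, 2, 0]               # trigger 0 is followed by 5 in context
print(predict(tokens, U, V, W)[0])        # 5
```

With these one-hot embeddings the score at position $t$ reduces to $\lambda_q \cdot [z_{t-1} = q]$ for the query trigger $q$, so the value read out is $E(o) + \tilde{E}(q)$ scaled by $\lambda_q / T$, and $\bm{U}$ turns it into a logit vector supported on the output token $o$.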

Figures (5)

  • Figure 1: Population and OOD test loss of models trained on the population loss in noiseless and noisy settings. First row, left to right: the population loss, OOD test losses of Origin (no parameterization) models and OOD test losses of re-parameterized models in noiseless learning. Second row, left to right: the corresponding losses as in the first row in the noisy setting with $\alpha = 0.5$.
  • Figure 2: Population and Unseen Output Test Losses of Reparam-Linear-W with $\eta =0.1$ (first row), $\eta = 0.5$ (second row). Population losses converge with limited generalization to unseen output words.
  • Figure 3: Origin-Linear with $\alpha =0.2$ (first row), $\alpha = 0.5$ (second row) and $\alpha = 0.8$ (third row). The learning rate is $\eta = 0.1$.
  • Figure 4: Reparam-Linear-W with $\alpha =0.2$ (first row), $\alpha = 0.5$ (second row) and $\alpha = 0.8$ (third row). The learning rate is $\eta = 0.1$.
  • Figure 5: Reparam-Linear with $\alpha =0.2$ (first row), $\alpha = 0.5$ (second row) and $\alpha = 0.8$ (third row). The learning rate is $\eta = 0.1$.

Theorems & Definitions (33)

  • Definition 2.1: Data Model - In-context Recall Tasks
  • Lemma 3.1
  • proof
  • Theorem 3.2
  • Theorem 3.3
  • Theorem 3.4
  • Lemma 4.1
  • Theorem 4.2
  • Lemma 5.1
  • Lemma 5.2
  • ...and 23 more