How Transformers Learn In-Context Recall Tasks? Optimality, Training Dynamics and Generalization
Quan Nguyen, Thanh Nguyen-Tang
TL;DR
This work analyzes how transformers learn in-context recall tasks using a one-layer decoder-only model with fixed embeddings and trainable unembedding, value, and joint query-key parameters. It shows that such models are Bayes-optimal under linear, ReLU, and softmax attention in both noiseless and noisy settings, proves that normalized gradient descent converges linearly to the Bayes risk, and gives finite-sample PAC-style generalization guarantees. A key finding is that proper parameterization is crucial for achieving both Bayes-optimal in-distribution performance and out-of-distribution (OOD) generalization; without such structure, models trained by gradient descent may fail to generalize OOD. Experiments in both noiseless and noisy regimes support the theory.
Abstract
We study the approximation capabilities, convergence speeds and on-convergence behaviors of transformers trained on in-context recall tasks, which require recognizing the \emph{positional} association between a pair of tokens from in-context examples. Existing theoretical results focus only on the in-context reasoning behavior of transformers after \emph{one} gradient descent step. It remains unclear what the on-convergence behavior of transformers trained by gradient descent is, and how fast they converge. In addition, the generalization of transformers in one-step in-context reasoning has not been formally investigated. This work addresses these gaps. We first show that a class of transformers with linear, ReLU or softmax attention is provably Bayes-optimal for an in-context recall task. We then show, via a finite-sample analysis, that when these transformers are trained with gradient descent, the expected loss converges at a linear rate to the Bayes risk. Moreover, we show that the trained transformers exhibit out-of-distribution (OOD) generalization, i.e., they generalize to samples outside of the population distribution. Our theoretical findings are further supported by extensive empirical validations, which show that \emph{without} proper parameterization, models with larger expressive power surprisingly \emph{fail} to generalize OOD after being trained by gradient descent.
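To make the task concrete, the following is a minimal NumPy sketch of a toy in-context recall problem and a one-layer attention model with a joint query-key matrix, a value matrix, and an unembedding. It is an illustration under stated assumptions, not the construction analyzed in the paper: the toy vocabulary, the embedding that concatenates the current and previous token one-hots, and the hand-set (untrained) weights in `build_weights` are all hypothetical choices made only so that the positional association is visible to a single attention layer.

```python
import numpy as np

VOCAB = 6  # hypothetical toy vocabulary size


def embed(tokens, vocab=VOCAB):
    """Fixed embedding (assumption): [one-hot(current); one-hot(previous)]."""
    X = np.zeros((len(tokens), 2 * vocab))
    for t, tok in enumerate(tokens):
        X[t, tok] = 1.0
        if t > 0:
            X[t, vocab + tokens[t - 1]] = 1.0
    return X


def build_weights(vocab=VOCAB):
    """Hand-set weights for illustration; these are not trained."""
    # Joint query-key: match the query's current token
    # against each position's previous token.
    W_qk = np.zeros((2 * vocab, 2 * vocab))
    W_qk[:vocab, vocab:] = np.eye(vocab)
    # Value: read out the current token at each attended position.
    W_v = np.hstack([np.eye(vocab), np.zeros((vocab, vocab))])
    # Unembedding: identity map from value space to vocabulary logits.
    W_u = np.eye(vocab)
    return W_qk, W_v, W_u


def predict(tokens, attention="softmax", beta=10.0):
    """Predict the next token from the last position of the sequence."""
    X = embed(tokens)
    W_qk, W_v, W_u = build_weights()
    # scores[t] = x_T^T W_qk x_t, with x_T the last position's embedding.
    scores = X @ (X[-1] @ W_qk)
    if attention == "linear":
        weights = scores
    elif attention == "relu":
        weights = np.maximum(scores, 0.0)
    elif attention == "softmax":
        z = beta * scores
        weights = np.exp(z - z.max())
        weights /= weights.sum()
    else:
        raise ValueError(f"unknown attention type: {attention}")
    logits = W_u @ (W_v @ (weights @ X))
    return int(np.argmax(logits))


# Token 0 acts as the trigger: it appears once in the context, followed
# by token 5, and again as the final token, so the recall target is 5.
tokens = [3, 4, 0, 5, 2, 4, 1, 0]
for kind in ("linear", "relu", "softmax"):
    print(kind, predict(tokens, kind))
```

In this noiseless toy sequence, the only position whose previous token equals the trigger is the one holding the recall target, so all three attention variants place their dominant weight there. The sketch says nothing about training dynamics or noisy settings, which are the subject of the paper's analysis.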
