
Towards Understanding the Universality of Transformers for Next-Token Prediction

Michael E. Sander, Gabriel Peyré

TL;DR

The paper investigates the mechanism behind next-token prediction in causal Transformers. It proposes a kernel-based interpretation in a vector-valued RKHS for sequences satisfying $x_{t+1}=f(x_t)$. It proves the existence of explicit Transformer constructions (with linear, exponential, or softmax attention) that implement causal kernel descent and asymptotically recover $x_{t+1}$ from $x_{1:t}$ when $f$ is linear or the sequence is periodic, with a Neural ODE viewpoint for the infinite-depth limit. A causal kernel descent framework is developed, with convergence shown via Kaczmarz-like arguments, and its updates are realized by Transformer layers acting on augmented tokens. Experiments validate the theory, illustrate how the prediction error decreases with $t$, and indicate that the approach can extend to broader mappings $f$. This suggests a path toward understanding the universality of in-context learning in autoregressive prediction.
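The Neural ODE viewpoint treats a stack of $n$ identical residual layers as a discretization of a continuous-depth flow. The sketch below assumes each layer performs $e \leftarrow e + \tfrac{1}{n} F(e)$. It uses a toy vector field $F(e) = -e$, a hypothetical stand-in and not the paper's layer. This update is exactly explicit Euler on $\dot e = F(e)$ over $[0, 1]$, so the gap to the exact flow shrinks as $n$ grows.

```python
import math

def residual_stack(e0, field, n_layers):
    """Apply n identical residual layers e <- e + (1/n) * field(e).

    This is explicit Euler on de/ds = field(e) for s in [0, 1].
    `field` is a hypothetical stand-in for one layer's update.
    """
    e = list(e0)
    h = 1.0 / n_layers
    for _ in range(n_layers):
        delta = field(e)
        e = [ei + h * di for ei, di in zip(e, delta)]
    return e

def decay(e):
    # Toy vector field F(e) = -e; its flow at s = 1 maps e0 to exp(-1) * e0.
    return [-ei for ei in e]

e0 = [1.0, -2.0]
exact = [math.exp(-1.0) * v for v in e0]

errors = {}
for n in (1, 10, 100, 1000):
    approx = residual_stack(e0, decay, n)
    errors[n] = max(abs(a - b) for a, b in zip(approx, exact))
print(errors)  # the error shrinks roughly like 1/n
```

Fixing the step to $1/n$ is the design choice that makes an infinite-depth limit exist. With a step of 1 per layer, the stack would not approach a fixed map as $n$ grows.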

Abstract

Causal Transformers are trained to predict the next token for a given context. While it is widely accepted that self-attention is crucial for encoding the causal structure of sequences, the precise underlying mechanism behind this in-context autoregressive learning ability remains unclear. In this paper, we take a step towards understanding this phenomenon by studying the approximation ability of Transformers for next-token prediction. Specifically, we explore the capacity of causal Transformers to predict the next token $x_{t+1}$ given an autoregressive sequence $(x_1, \dots, x_t)$ as a prompt, where $x_{t+1} = f(x_t)$, and $f$ is a context-dependent function that varies with each sequence. On the theoretical side, we focus on specific instances, namely when $f$ is linear or when $(x_t)_{t \geq 1}$ is periodic. We explicitly construct a Transformer (with linear, exponential, or softmax attention) that learns the mapping $f$ in-context through a causal kernel descent method. The causal kernel descent method we propose provably estimates $x_{t+1}$ based solely on past and current observations $(x_1, \dots, x_t)$, with connections to the Kaczmarz algorithm in Hilbert spaces. We present experimental results that validate our theoretical findings and suggest their applicability to more general mappings $f$.
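The abstract links causal estimation of $x_{t+1}$ to the Kaczmarz algorithm. The following is a minimal sketch of that idea for linear $f$, assuming $x_{s+1} = W x_s$ with an unknown matrix $W$:

- Cycle through the observed pairs $(x_s, x_{s+1})$ for $s < t$.
- Project the current estimate $M$ onto $\{M : M x_s = x_{s+1}\}$ in the Frobenius inner product.
- Predict $M x_t$.

This is plain cyclic Kaczmarz applied row-wise, not the paper's causal kernel descent update. The rotation matrix, starting point, and sweep count are illustrative choices.

```python
import math

def mat_vec(M, v):
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

def kaczmarz_estimate(xs, n_sweeps):
    """Estimate x_{t+1} using only xs = [x_1, ..., x_t] (causal).

    Assumes x_{s+1} = W x_s for an unknown matrix W and runs cyclic
    Kaczmarz projections on the constraints M x_s = x_{s+1}, s < t.
    """
    d = len(xs[0])
    M = [[0.0] * d for _ in range(d)]  # initial guess M = 0
    pairs = list(zip(xs[:-1], xs[1:]))
    for _ in range(n_sweeps):
        for x, y in pairs:
            r = [yi - pi for yi, pi in zip(y, mat_vec(M, x))]  # residual y - M x
            nrm2 = sum(v * v for v in x)
            # Minimal-norm rank-one correction that makes M x = y hold exactly.
            M = [[M[i][j] + r[i] * x[j] / nrm2 for j in range(d)] for i in range(d)]
    return mat_vec(M, xs[-1])

theta = 0.7  # illustrative: W is a rotation about the z-axis

def apply_w(v):
    c, s = math.cos(theta), math.sin(theta)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1], v[2]]

xs = [[1.0, 0.0, 0.5]]
for _ in range(9):
    xs.append(apply_w(xs[-1]))  # x_1, ..., x_10

u = kaczmarz_estimate(xs, n_sweeps=300)
target = apply_w(xs[-1])  # true x_11, never shown to the estimator
err = sum((a - b) ** 2 for a, b in zip(u, target))
print(err)  # close to 0 once the observed x_s span R^3
```

The estimate can only be exact once the observed $x_s$ span $\mathbb{R}^d$. This matches the requirement that the prediction improves as the context length $t$ grows.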

Paper Structure

This paper contains 43 sections, 9 theorems, 73 equations, and 4 figures.

Key Result

Proposition 1

There exists a sequence of one-layer, two-head causal Transformers $\mathcal{T}^n_0$ with $\mathcal{N}=\text{softmax}$, each followed by a feedforward layer, such that $\mathcal{T}_0(x_{0:t}) \coloneqq \lim_{n\to +\infty} \mathcal{T}^n_0(x_{0:t}) = e^0_t$.
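Proposition 1 obtains $\mathcal{T}_0$ only as a limit of softmax Transformers. One generic reason such limits help is that softmax attention tends to hard attention on the highest-scoring position as the scores are scaled by a growing inverse temperature $\beta$. The sketch below illustrates this with a hypothetical positional score that peaks at the previous position $t-1$. It is not the paper's construction and makes no claim about what $e^0_t$ contains.

```python
import math

def causal_softmax_attention(scores, values, beta):
    """Output of the last query attending to positions 0..t (causal: no future keys).

    scores[s] is the query-key score for position s; beta is an inverse temperature.
    """
    m = max(scores)
    w = [math.exp(beta * (sc - m)) for sc in scores]  # shift by the max for stability
    z = sum(w)
    d = len(values[0])
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(d)]

values = [[0.0, 1.0], [2.0, -1.0], [3.0, 0.5], [-1.0, 4.0]]  # x_0, ..., x_3
t = len(values) - 1
# Hypothetical positional score, maximal at the previous position s = t - 1.
scores = [-float((s - (t - 1)) ** 2) for s in range(t + 1)]

for beta in (1.0, 10.0, 100.0):
    print(beta, causal_softmax_attention(scores, values, beta))
```

For small $\beta$ the output mixes neighbouring tokens. For large $\beta$ it copies $x_{t-1}$, an exact discrete operation that no single finite-temperature softmax layer produces.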

Figures (4)

  • Figure 1: Illustration of the method proposed in this paper. Given a sequence $x_{1:t}$, a first layer $\mathcal{T}_0$ computes augmented tokens $e^0_{1:t}$. Next, a stack of $n$ identical Transformer layers $\mathcal{T}$ with residual connections iteratively updates the tokens $e^k_{1:t}$, following the causal kernel descent method introduced in Section \ref{sec:causal_kernel_descent}. For the autoregressive sequences of Assumption \ref{def:seq}, and under the specific instances listed in Assumption \ref{def:instances}, projecting $e^n_t$ with a projector $P$ yields an estimate $u^n_t$ of $x_{t+1}$ as $n$ and $t$ approach $+\infty$, as stated in Theorem \ref{thm:main}.
  • Figure 2: For random vectors $\nu$ (green) and $\nu_1, \dots, \nu_6$ on $S^2$, we display $P_6\nu$ (grey), $P_5P_6\nu$ (orange), and so on up to $P_1 \cdots P_6 \nu$.
  • Figure 3: Evolution of the squared error $\|u^{\star}_t - x_{t+1}\|^2$ with $t$ for different scenarios. The curves are averaged over five sequences $x_{1:t}$. Left: instances (1) and (2) (with random $W$ and $\Omega$), illustrating Theorems \ref{thm:linear_linear} and \ref{thm:stationary} ($d = 15$). Center: instance (3), illustrating Theorem \ref{thm:periodic} ($d = 15$, the period $t_p$ is randomly sampled between $20$ and $40$, and a random sequence is repeated $t_p$ times). Right: instance (4) described in Section \ref{sec:experiments} ($d = 4$).
  • Figure 4: Errors $\frac{1}{d}\|\mathcal{G}(x_{1:t})- x_{t+1}\|^2$ against $t$ for $\mathcal{G} \in \{\mathcal{M}, \mathcal{M}^n_{\theta_0}, \mathcal{M}^n_{\theta_{\star}} \}$. Results are averaged over the whole test set.
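Figure 2 gives the geometric picture behind Kaczmarz-type arguments. The caption does not define $P_i$. The sketch below assumes $P_i = I - \nu_i \nu_i^\top$, the orthogonal projection onto the plane orthogonal to the unit vector $\nu_i$. It applies $P_6$ first and $P_1$ last to a random $\nu$ on $S^2$ and checks that each projection can only shrink the norm.

```python
import math
import random

def random_unit(rng, d=3):
    # Uniform random point on the sphere S^{d-1} via a normalized Gaussian vector.
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    n = math.sqrt(sum(c * c for c in v))
    return [c / n for c in v]

def project_out(u, v):
    """Apply P = I - u u^T (u a unit vector): projection onto the plane orthogonal to u."""
    dot = sum(a * b for a, b in zip(u, v))
    return [b - dot * a for a, b in zip(u, v)]

rng = random.Random(0)
nu = random_unit(rng)
nus = [random_unit(rng) for _ in range(6)]  # nu_1, ..., nu_6

trajectory = [nu]  # nu, P_6 nu, P_5 P_6 nu, ..., P_1 ... P_6 nu
v = nu
for u in reversed(nus):  # P_6 is applied first, P_1 last
    v = project_out(u, v)
    trajectory.append(v)

norms = [math.sqrt(sum(c * c for c in w)) for w in trajectory]
print(norms)  # non-increasing, since ||P v||^2 = ||v||^2 - <u, v>^2
```

Because every step removes the component $\langle \nu_i, v\rangle \nu_i$, the norm decrease at each step equals the squared residual. Convergence proofs of the Kaczmarz type exploit this.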

Theorems & Definitions (19)

  • Proposition 1
  • Theorem 1: On the expressivity of Transformers for Next Token Prediction
  • Remark 1
  • Proposition 2
  • Proposition 3
  • Proposition 4: Estimate Update Recursion
  • Theorem 2: $k = k_{\text{id}}$, linear recursions
  • Theorem 3: $k=k_{\text{exp}}$, linear recursions
  • Theorem 4: $k=k_{\text{exp}}$, periodic recursions
  • Proposition 5
  • ...and 9 more
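Theorem 4 concerns periodic sequences with the exponential kernel. The following is a minimal kernel Kaczmarz sketch in a vector-valued RKHS. It assumes $k_{\text{exp}}(x, y) = \exp(\langle x, y\rangle)$ and $f = \sum_s \alpha_s k(x_s, \cdot)$ with $\alpha_s \in \mathbb{R}^d$. Each step updates one coefficient so that $f(x_s) = x_{s+1}$ holds exactly, which is the orthogonal projection in the RKHS onto that constraint. The period-4 sequence on the unit circle is an illustrative choice, and this is not claimed to be the paper's update rule.

```python
import math

def k_exp(x, y):
    # Assumed form of the exponential kernel: k(x, y) = exp(<x, y>).
    return math.exp(sum(a * b for a, b in zip(x, y)))

def kernel_kaczmarz_predict(xs, n_sweeps, kernel=k_exp):
    """Estimate x_{t+1} causally from xs = [x_1, ..., x_t].

    f = sum_s alpha_s k(x_s, .) with vector coefficients alpha_s; each step
    projects f onto {g : g(x_s) = x_{s+1}} in the RKHS of kernel * identity.
    """
    pairs = list(zip(xs[:-1], xs[1:]))
    d = len(xs[0])
    alphas = [[0.0] * d for _ in pairs]

    def f(x):
        out = [0.0] * d
        for (xp, _), a in zip(pairs, alphas):
            kv = kernel(xp, x)
            out = [o + kv * ai for o, ai in zip(out, a)]
        return out

    for _ in range(n_sweeps):
        for i, (x, y) in enumerate(pairs):
            r = [yi - fi for yi, fi in zip(y, f(x))]  # residual at x_s
            kxx = kernel(x, x)
            # Adding (r / k(x, x)) k(x, .) to f makes f(x) = y exactly.
            alphas[i] = [ai + ri / kxx for ai, ri in zip(alphas[i], r)]
    return f(xs[-1])

period = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
xs = period * 3  # x_1, ..., x_12 with period t_p = 4
target = period[len(xs) % 4]  # true x_13
u = kernel_kaczmarz_predict(xs, n_sweeps=100)
print(u, target)
```

Within one period, the points must be distinct for $x_{t+1} = f(x_t)$ to define a function. Given distinct points, the exponential kernel's Gram matrix is positive definite, and the projections converge to an interpolant of the observed transitions.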