Mechanics of Next Token Prediction with Self-Attention

Yingcong Li, Yixiao Huang, M. Emrullah Ildiz, Ankit Singh Rawat, Samet Oymak

TL;DR

Gradient descent is proved to implicitly discover the strongly-connected components (SCCs) of a directed graph over tokens extracted from the training data, and self-attention learns to retrieve the tokens that belong to the highest-priority SCC available in the context window.

Abstract

Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this simple training objective, they have led to revolutionary advances in natural language processing. Underlying this success is the self-attention mechanism. In this work, we ask: $\textit{What does a single self-attention layer learn from next-token prediction?}$ We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps: $\textbf{(1) Hard retrieval:}$ Given an input sequence, self-attention precisely selects the $\textit{high-priority input tokens}$ associated with the last input token. $\textbf{(2) Soft composition:}$ It then creates a convex combination of the high-priority tokens from which the next token can be sampled. Under suitable conditions, we rigorously characterize these mechanics through a directed graph over tokens extracted from the training data. We prove that gradient descent implicitly discovers the strongly-connected components (SCCs) of this graph and self-attention learns to retrieve the tokens that belong to the highest-priority SCC available in the context window. Our theory relies on decomposing the model weights into a directional component and a finite component that correspond to the hard retrieval and soft composition steps, respectively. This also formalizes a related implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.

Paper Structure (34 sections, 21 theorems, 135 equations, 12 figures)

Key Result

Theorem 1

Consider training a single-layer self-attention model with gradient descent. The combined attention weights $\bm{W}:=\bm{W}_K\bm{W}_Q^\top$ evolve as $\bm{W}\approx C\cdot\bm{W}_{\text{hard}}+\bm{W}_{\text{soft}}$, where $C\cdot\bm{W}_{\text{hard}}$ is the hard retrieval component, which selects the high-priority tokens as $C\to\infty$, and $\bm{W}_{\text{soft}}$ is the soft composition component, which allocates nonzero softmax probabilities over the selected tokens.
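The interplay between the scaled hard component and the soft component in Theorem 1 can be illustrated with a small numerical sketch. It is not the paper's construction: it assumes one-hot token embeddings, so that the attention score of token `t` against the last token `q` is simply entry `(t, q)` of the combined weights, and the matrices `W_HARD` and `W_SOFT` are hypothetical toy values chosen by hand.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_probs(seq, w_hard, w_soft, c):
    """Attention probabilities over `seq`, queried by its last token.

    Assumption: tokens are one-hot embedded, so the score of token t
    against query q is entry (t, q) of c * W_hard + W_soft.
    """
    q = seq[-1]
    scores = [c * w_hard[t][q] + w_soft[t][q] for t in seq]
    return softmax(scores)

# Hypothetical toy weights for a 3-token vocabulary (rows: token, cols: query).
# For query token 0, tokens 0 and 1 share the top hard score; token 2 is lower.
W_HARD = [[1.0, 0.0, 0.0],
          [1.0, 0.0, 0.0],
          [0.0, 0.0, 0.0]]
# The soft scores favor token 2 when c is small, and weight token 1
# three times as heavily as token 0.
W_SOFT = [[0.0, 0.0, 0.0],
          [math.log(3.0), 0.0, 0.0],
          [5.0, 0.0, 0.0]]

if __name__ == "__main__":
    seq = [2, 1, 0]  # last token (the query) is 0
    for c in (0.0, 1.0, 10.0, 100.0):
        print(c, [round(p, 4) for p in attention_probs(seq, W_HARD, W_SOFT, c)])
```

In this toy setting, a small `c` lets the soft scores dominate, so token 2 receives most of the mass. As `c` grows, token 2 is suppressed, and the remaining mass is split between tokens 0 and 1 in the ratio fixed by the soft scores alone.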

Figures (12)

  • Figure 1: Overview of our result on next-token prediction. We study the implicit bias of gradient descent where a 1-layer self-attention model is trained until convergence. We prove that, at test time, this model implements a hard retrieval to precisely select the high-priority tokens and then outputs a convex combination of them, from which the next token can be sampled. The notion of high priority is formalized through the strongly-connected components of a directed graph associated with the last input token.
  • Figure 2: A token-priority graph (TPG) is a directed graph derived from the training data (see Sec. \ref{sec ntg} for the definition). The edges of a TPG capture the input-output relationships between tokens. A TPG can be partitioned into several SCCs, depicted as dashed black squares. In light of Theorem 1, the black edges within each SCC induce the soft-composition component of the attention weights, whereas the orange edges induce the hard-retrieval component, which enforces the priority order among the SCCs.
  • Figure 3: Illustration of the token-priority graph (TPG). Given the input sequences and labels (next tokens), we construct the TPGs $\{{\cal{G}}^{(k)}\}_{k=1}^K$ according to the last token. Two TPGs ${\cal{G}}^{(1)}$ (left) and ${\cal{G}}^{(2)}$ (right) are constructed using the samples with $\bm{e}_1$ and $\bm{e}_2$ as the last tokens, respectively. In each graph, directed edges (label token$\to$input token) are added between tokens/nodes. Based on these directed edges, each graph can be partitioned into its strongly-connected components (SCCs, highlighted as dashed grey rectangles). Each SCC is a set of tokens where each token is reachable from every other token within that SCC. Further details are deferred to Sec. \ref{sec ntg}.
  • Figure 4: GD convergence of the attention weights $\bm{W}$ when training with a general dataset. (a) shows the directional convergence of $\bm{W}(\tau)$, while (b) presents the convergence of $\boldsymbol{\Pi}_{\mathcal{S}_{\text{fin}}}(\bm{W}(\tau))$.
  • Figure 5: GD convergence of the attention weights $\bm{W}$ when training with an acyclic dataset (Definition 2). The correlation coefficient between $\bm{W}(\tau)$ and $\bm{W}^\text{svm}$ is presented.
  • ...and 7 more figures
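
The graph construction described in the Figure 3 caption (one graph per last token, edges from label token to input token, then a partition into SCCs) can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' code: the function names `build_tpgs` and `strongly_connected_components` are hypothetical, the toy samples are made up, self-loops are skipped for simplicity, and the SCCs are found with Tarjan's algorithm, which the source does not prescribe.

```python
from collections import defaultdict

def build_tpgs(samples):
    """Build one directed graph per last token.

    `samples` is an iterable of (input_sequence, label) pairs. For each
    pair, an edge label -> token is added for every token in the sequence
    (self-loops are skipped; an assumption made here for simplicity).
    Returns {last_token: {node: set_of_successors}}.
    """
    graphs = defaultdict(lambda: defaultdict(set))
    for seq, label in samples:
        g = graphs[seq[-1]]
        g[label]  # register the label as a node
        for tok in seq:
            g[tok]  # register the input token as a node
            if tok != label:
                g[label].add(tok)
    return {k: dict(g) for k, g in graphs.items()}

def strongly_connected_components(graph):
    """Tarjan's algorithm (recursive); returns a list of frozensets."""
    index, low, on_stack, stack, sccs = {}, {}, set(), [], []

    def visit(v):
        index[v] = low[v] = len(index)
        stack.append(v)
        on_stack.add(v)
        for w in graph[v]:
            if w not in index:
                visit(w)
                low[v] = min(low[v], low[w])
            elif w in on_stack:
                low[v] = min(low[v], index[w])
        if low[v] == index[v]:  # v is the root of an SCC
            comp = set()
            while True:
                w = stack.pop()
                on_stack.discard(w)
                comp.add(w)
                if w == v:
                    break
            sccs.append(frozenset(comp))

    for v in graph:
        if v not in index:
            visit(v)
    return sccs

# Hypothetical toy dataset: tokens are integers, each pair is (sequence, label).
SAMPLES = [
    ([2, 3, 1], 2),  # last token 1: edges 2->3, 2->1
    ([2, 3, 1], 3),  # last token 1: edges 3->2, 3->1
    ([4, 1], 1),     # last token 1: edge 1->4
    ([1, 2], 1),     # last token 2: edge 1->2
]

if __name__ == "__main__":
    for last, graph in build_tpgs(SAMPLES).items():
        print(last, strongly_connected_components(graph))
```

In the toy data, tokens 2 and 3 each appear as the label while the other is in the input, so they are mutually reachable and form one SCC in the graph for last token 1; tokens 1 and 4 each form a singleton SCC.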

Theorems & Definitions (28)

  • Theorem 1: informal
  • Lemma 1
  • Definition 1: Cyclic subspace
  • Lemma 2
  • Theorem 2
  • Theorem 3
  • Definition 2: Acyclic dataset
  • Lemma 3
  • Theorem 4
  • Definition 3: Cyclic subdataset
  • ...and 18 more