
Birth of a Transformer: A Memory Viewpoint

Alberto Bietti, Vivien Cabannes, Diane Bouchacourt, Herve Jegou, Leon Bottou

TL;DR

This work tackles how transformers balance global knowledge with in-context information by constructing a synthetic bigram task that separates persistent world knowledge from context-specific cues. Using a simplified two-layer transformer and an associative-memory perspective, the authors show global bigrams are learned quickly while an induction head emerges through top‑down gradient steps that tune key–query memories to capture in-context associations. Theoretical analyses link population gradients to memory formation and demonstrate how gradient updates can recover useful associations from noisy residual streams. Overall, the study provides a memory-centric lens on learning dynamics in transformers, with implications for optimization, data preprocessing, and mechanistic interpretability in language models.

Abstract

Large language models based on transformers have achieved great empirical successes. However, as they are deployed more widely, there is a growing need to better understand their internal mechanisms in order to make them more reliable. These models appear to store vast amounts of knowledge from their training data, and to adapt quickly to new information provided in their context or prompt. We study how transformers balance these two types of knowledge by considering a synthetic setup where tokens are generated from either global or context-specific bigram distributions. By a careful empirical analysis of the training process on a simplified two-layer transformer, we illustrate the fast learning of global bigrams and the slower development of an "induction head" mechanism for the in-context bigrams. We highlight the role of weight matrices as associative memories, provide theoretical insights on how gradients enable their learning during training, and study the role of data-distributional properties.

Paper Structure

This paper contains 56 sections, 5 theorems, 71 equations, 11 figures.

Key Result

Lemma 1

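Let $p$ be a data distribution over input-output tokens $(z, y)$, and consider the following loss, where the input and output embeddings $W_E$ and $W_U$ are fixed:

$$L(W) = \mathbb{E}_{(z,y) \sim p}\left[\ell(y, W_U W w_E(z))\right],$$

with $\ell$ the cross-entropy loss. The gradients of the population loss $L$ then take the form

$$\nabla_W L(W) = \sum_{k=1}^{N} \mathbb{E}_z\left[\left(\hat{p}_W(y\!=\!k|z) - p(y\!=\!k|z)\right) w_U(k)\, w_E(z)^\top\right],$$

where $N$ is the vocabulary size, $w_U(k)$ is the output embedding of token $k$, and $\hat{p}_W(y\!=\!k|z) =\sigma(W_U W w_E(z))_k$ are the model's predicted probabilities. Running gradient descent (with or without weight decay) …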
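As a sanity check on Lemma 1, the sketch below compares the standard cross-entropy gradient, $\sum_k \mathbb{E}_z[(\hat{p}_W(y\!=\!k|z) - p(y\!=\!k|z))\, w_U(k)\, w_E(z)^\top]$, against a finite-difference estimate on a toy problem. The vocabulary size, embedding dimension, random embeddings, and random joint distribution are all illustrative assumptions and are not taken from the paper's experiments.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 5, 8  # hypothetical vocabulary size and embedding dimension

# Fixed random embeddings (assumption): row z of W_E is w_E(z), row k of W_U is w_U(k).
W_E = rng.normal(size=(N, d)) / np.sqrt(d)
W_U = rng.normal(size=(N, d)) / np.sqrt(d)

# Arbitrary joint distribution p(z, y) over input-output token pairs.
p_joint = rng.random((N, N))
p_joint /= p_joint.sum()

def softmax(v):
    e = np.exp(v - v.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def predicted_probs(W):
    # Entry [z, k] is sigma(W_U W w_E(z))_k.
    return softmax(W_E @ W.T @ W_U.T)

def loss(W):
    # Population cross-entropy: -sum_{z,y} p(z, y) log p_hat(y | z).
    return -(p_joint * np.log(predicted_probs(W))).sum()

def grad_formula(W):
    p_z = p_joint.sum(axis=1)                 # marginal p(z)
    p_cond = p_joint / p_z[:, None]           # p(y = k | z)
    coeff = p_z[:, None] * (predicted_probs(W) - p_cond)
    # sum_{z,k} coeff[z, k] * w_U(k) w_E(z)^T
    return W_U.T @ coeff.T @ W_E

def grad_numeric(W, eps=1e-5):
    # Central finite differences, one entry of W at a time.
    G = np.zeros_like(W)
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            E = np.zeros_like(W)
            E[i, j] = eps
            G[i, j] = (loss(W + E) - loss(W - E)) / (2 * eps)
    return G

W = rng.normal(size=(d, d))
max_err = np.abs(grad_formula(W) - grad_numeric(W)).max()
print("max abs difference:", max_err)
```

The closed form is a sum of outer products $w_U(k)\, w_E(z)^\top$, which is why a gradient step on $W$ can be read as writing input-output associations into a memory matrix.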

Figures (11)

  • Figure 1: Induction head mechanism. Induction heads are a two-layer mechanism that can predict $b$ from a context $[\ldots, a, b, \ldots, a]$. The first layer is a previous token head, which attends to the previous token based on positional embeddings ($p_t \to p_{t-1}$) and copies it after a remapping ($w_E(a) \to w_1(a) := W_O^1 W_V^1 w_E(a)$). The second layer is the induction head, which attends based on the output of the previous token head ($w_E(a) \to w_1(a)$) and outputs the attended token, remapped to output embeddings ($w_E(b) \to w_U(b)$). Boxes in the diagram represent different embeddings in superposition on each token's residual stream (we omit some irrelevant ones for clarity, e.g., positional embeddings in upper layers), and attention and output associations are shown with the associative memory viewpoint presented in Section \ref{sec:memory}.
  • Figure 2: Induction head behavior in attention maps observed on a 2-layer transformer trained on two variants of our synthetic dataset. Each row shows the attention pattern for predicting the next token. (left) The first layer head always attends to the previous token. (center) For fixed triggers $Q = \{a,t\}$, the second layer head mainly attends to tokens following such triggers. (right) For random triggers, the induction head mechanism is active for any repeated token (here the only trigger is $L$). Red and green boxes highlight tokens following previous occurrences of the query, with red boxes corresponding to "correct" output tokens $o_k$ following trigger tokens $q_k$.
  • Figure 3: Learning the induction head alone: in-context accuracy (top) and recall probes (bottom) with some layers frozen until iteration 300. The output matrix $W_O^2$ can and must be learned before the key-query matrices, but does not suffice for good accuracy. It is easier to learn $W_K^2$ before $W_K^1$, and $W_K^1$ stores initial context positions ($t < 64$) much faster than late positions.
  • Figure 4: Global vs in-context learning and data-distributional effects. (left) Loss on global (dashed) vs in-context (solid) tokens throughout training, for fixed or random trigger tokens $q_k$. The red curve fixes the trigger $q_1$ to the most frequent token, while the fixed triggers in the blue curves are less common. (center) In-context accuracy with different training and test distributions $\pi_o$ for output tokens. A uniform $\pi_o$ leads to better generalization than global bigrams $\pi_b$. (right) Probe metrics throughout training: $W_O^2$ and $W_F$ eventually compete and deviate from our natural estimates.
  • Figure 5: Memory recall probes for the setting of Figure 4 (left).
  • ...and 6 more figures

Theorems & Definitions (5)

  • Lemma 1: Gradients and associative memories
  • Lemma 2: Gradient associative memory with noisy inputs
  • Theorem 3: Learning induction head via three gradient steps, informal
  • Lemma 4: Gradient of second attention layer
  • Lemma 5: Gradient of first attention layer
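
The associative-memory view behind these results can be illustrated with a minimal sketch: a weight matrix built as a sum of outer products $w_U(b)\, w_E(a)^\top$ stores pairs $a \to b$, and recall is an argmax over output embeddings. The random Gaussian embeddings, the sizes `N` and `d`, the random `target` mapping, and the noise scale are assumptions chosen for illustration and are not the paper's setup.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 32, 512  # hypothetical vocabulary size and embedding dimension

# Random high-dimensional embeddings are nearly orthogonal (assumption).
W_E = rng.normal(size=(N, d)) / np.sqrt(d)   # row a is w_E(a)
W_U = rng.normal(size=(N, d)) / np.sqrt(d)   # row b is w_U(b)

# Hypothetical associations to store: token a maps to token target[a].
target = rng.permutation(N)

# Associative memory: W = sum_a w_U(target[a]) w_E(a)^T.
W = W_U[target].T @ W_E

def recall(x):
    # Score each output token k by w_U(k)^T W x and return the best one.
    return int(np.argmax(W_U @ W @ x))

# Recall from clean input embeddings.
clean_acc = np.mean([recall(W_E[a]) == target[a] for a in range(N)])

# Recall when the input is superposed with an unrelated random vector,
# a rough stand-in for other content in a residual stream.
noisy_acc = np.mean([
    recall(W_E[a] + rng.normal(size=d) / np.sqrt(d)) == target[a]
    for a in range(N)
])

print("clean recall accuracy:", clean_acc)
print("noisy recall accuracy:", noisy_acc)
```

Because the cross terms between distinct random embeddings shrink as `d` grows, the stored pair dominates the score even when the input carries extra superposed content; shrinking `d` relative to `N` makes recall errors more likely.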