
The Evolution of Statistical Induction Heads: In-Context Learning Markov Chains

Benjamin L. Edelman, Ezra Edelman, Surbhi Goel, Eran Malach, Nikolaos Tsilivis

TL;DR

We present a controlled study of in-context learning (ICL) by training Transformers on an ICL-MC task where each sequence is generated from a Markov chain drawn from a Dirichlet prior. The main contribution is the discovery of statistical induction heads that implement in-context bigram statistics, with learning unfolding in distinct phases from uniform predictions to unigram-based, and finally to Bayes-optimal bigram predictions. A minimal linear-transformer model corroborates the two-step gradient dynamics, showing the second layer learns first and that unigram signals can hinder rapid formation of the bigram solution; extending to $n$-grams demonstrates generalization of the hierarchical learning phenomenon. These results illuminate mechanistic pathways for ICL in LLMs and suggest how simple priors and curriculum-like shifts shape the emergence of complex in-context algorithms, with potential implications for understanding and improving in-context reasoning in real language models.

Abstract

Large language models have the ability to generate text that mimics patterns in their inputs. We introduce a simple Markov chain sequence modeling task in order to study how this in-context learning (ICL) capability emerges. In our setting, each example is sampled from a Markov chain drawn from a prior distribution over Markov chains. Transformers trained on this task form statistical induction heads, which compute accurate next-token probabilities given the bigram statistics of the context. During the course of training, models pass through multiple phases: after an initial stage in which predictions are uniform, they learn to sub-optimally predict using in-context single-token statistics (unigrams); then, there is a rapid phase transition to the correct in-context bigram solution. We conduct an empirical and theoretical investigation of this multi-phase process, showing how successful learning results from the interaction between the transformer's layers, and uncovering evidence that the presence of the simpler unigram solution may delay formation of the final bigram solution. We examine how learning is affected by varying the prior distribution over Markov chains, and consider the generalization of our in-context learning of Markov chains (ICL-MC) task to $n$-grams for $n > 2$.
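The task and the two in-context strategies named above can be sketched in a few lines. This is a minimal illustration, not the authors' code: the function names, the symmetric Dirichlet parameter `alpha`, the uniform choice of the first state, and the add-`alpha` smoothing in both predictors are assumptions made for the example.

```python
import random

def sample_dirichlet(alpha, k, rng):
    """Draw one point from a symmetric Dirichlet(alpha) on k outcomes."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]

def sample_chain(k, alpha, rng):
    """Transition matrix whose rows are independent Dirichlet draws."""
    return [sample_dirichlet(alpha, k, rng) for _ in range(k)]

def sample_sequence(P, length, rng):
    """Sample a sequence from P; the first state is uniform (an assumption)."""
    k = len(P)
    seq = [rng.randrange(k)]
    for _ in range(length - 1):
        seq.append(rng.choices(range(k), weights=P[seq[-1]])[0])
    return seq

def unigram_predict(seq, k, alpha=1.0):
    """Predict the next token from smoothed in-context token frequencies."""
    counts = [alpha] * k
    for tok in seq:
        counts[tok] += 1
    total = sum(counts)
    return [c / total for c in counts]

def bigram_predict(seq, k, alpha=1.0):
    """Predict from smoothed counts of tokens that followed earlier copies of the last token."""
    counts = [alpha] * k
    last = seq[-1]
    for prev, nxt in zip(seq, seq[1:]):
        if prev == last:
            counts[nxt] += 1
    total = sum(counts)
    return [c / total for c in counts]

if __name__ == "__main__":
    rng = random.Random(0)
    P = sample_chain(k=3, alpha=1.0, rng=rng)
    seq = sample_sequence(P, length=100, rng=rng)
    print("true row:", P[seq[-1]])
    print("unigram :", unigram_predict(seq, 3))
    print("bigram  :", bigram_predict(seq, 3))
```

With long contexts the bigram estimate tends toward the true transition row for the last token, whereas the unigram estimate only tracks how often each token appears overall.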

Paper Structure (34 sections, 6 theorems, 74 equations, 13 figures)


Key Result

Proposition 2.2

A single-head, two-layer, attention-only transformer can compute the bigram statistics in the in-context learning Markov chain (ICL-MC) task.
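One way to picture such a construction is with idealized hard attention. The sketch below is an assumption-laden illustration rather than the construction from the paper: the first layer copies each position's previous token, and the second layer lets the last position attend uniformly to every position whose copied previous token equals the current token. The function names are hypothetical, and softmax attention, embeddings, and weight matrices are deliberately omitted.

```python
def layer1_previous_token(seq):
    """Idealized first layer: position i attends only to position i-1 and copies its token."""
    return [None] + seq[:-1]

def layer2_attention(seq, prev_tokens):
    """Idealized second layer: the last position attends uniformly to every
    position whose copied previous token equals the current (last) token."""
    last = seq[-1]
    matches = {i for i, prev in enumerate(prev_tokens) if prev == last}
    if not matches:
        return [0.0] * len(seq)
    weight = 1.0 / len(matches)
    return [weight if i in matches else 0.0 for i in range(len(seq))]

def induction_head_output(seq, k):
    """Attention-weighted average of one-hot token values, which equals the
    empirical distribution of tokens that followed earlier copies of the last token."""
    attn = layer2_attention(seq, layer1_previous_token(seq))
    out = [0.0] * k
    for weight, tok in zip(attn, seq):
        out[tok] += weight
    return out

if __name__ == "__main__":
    # The last token is 2; earlier 2s were followed by 0, 1, 0.
    print(induction_head_output([2, 0, 2, 1, 2, 0, 2], k=3))
```

In this idealization the output is exactly the unsmoothed in-context bigram frequency for the current token, which is the statistic the proposition says the transformer can compute.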

Figures (13)

  • Figure 1: (left) We train small transformers to perform in-context learning of Markov chains (ICL-MC)---next-token prediction on the outputs of Markov chains. Each training sequence is generated by sampling a transition matrix from a prior distribution, and then sampling a sequence from this Markov chain. (right) Distance of a transformer's output distribution to several well-defined strategies over the course of training on our in-context Markov chain task. The model passes through three stages: (1) predicting a uniform distribution, (2) predicting based on in-context unigram statistics, (3) predicting based on in-context bigram statistics. Shading is based on the minimum of the curves.
  • Figure 2: Attention for a fixed input at various time steps in training. These diagrams show where the attention heads attend at each layer. In the second layer, only the attention of the last token is shown. Tokens on top attend to tokens below them. Attention starts off uniform, but by the end of training the layers clearly match the induction head construction. Specifically, in the first layer each token attends to the previous token. In the second layer, the current token, a $2$, attends to tokens that followed $2$s, allowing bigram statistics to be calculated. The full attention matrices are shown as heatmaps in a separate figure (label fig:attn_heatmap).
  • Figure 3: A two-layer transformer (top) and a minimal model (bottom) trained on our in-context Markov chain task. A comparison of the two-layer attention-only transformer and the minimal model (equation label eq:min-mod-def), with $v$ having constant uniform initialization and $W_K$ initialized to $0$. The graphs on the left show test loss, measured as KL divergence from the underlying truth. The green line shows the loss of the unigram strategy, and the orange line shows the loss of the bigram strategy. The middle graphs show the effective positional encoding (for the transformer, these are for the first layer, averaged over all tokens). The graphs on the right show the KL divergence between the outputs of the models and three strategies. The lower the KL divergence, the more similar the model is to that strategy.
  • Figure 4: (a) Transformers trained on different data distributions and evaluated on the original distribution. Color displays a smooth interpolation between a data distribution where the unigram strategy is uninformative (purple, 0) and one where the unigram strategy is Bayes optimal (yellow, 1). When there is not much signal from unigrams, learning progresses faster, without long plateaus. See the appendix paragraph labeled para:interpolate for a description of the data distributions. (b) Training of the minimal model on in-context learning of Markov chains with $k=2$ states. (left) Heatmap of the second layer (the $W_K$ matrix), which learns to be close to diagonal. (right) The values of the positional embeddings (first layer), which display a curious even/odd pattern. The timestep shown corresponds to a phase in which the model has started implementing the bigram solution but has not yet converged.
  • Figure 5: Three-headed transformer trained on in-context learning of 3-grams (trigrams), with context length 200. Left: Loss during training. The model hierarchically converges close to the Bayes optimal solution. Right: KL divergence between the model and different strategies during training. There are four stages of learning, each corresponding to a different algorithm implemented by the model.
  • ...and 8 more figures
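Several of the captions above compare a model's output to reference strategies by KL divergence and identify the stage of training by the closest strategy. A rough sketch of that comparison follows. The direction KL(strategy || model), the helper names, and the example numbers are assumptions for illustration and are not taken from the paper.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same outcomes.
    Terms with p_i == 0 contribute nothing; q_i must be positive where p_i > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def closest_strategy(model_probs, strategies):
    """Return (name, divergence) for the strategy with the smallest KL(strategy || model)."""
    scored = {name: kl_divergence(probs, model_probs) for name, probs in strategies.items()}
    best = min(scored, key=scored.get)
    return best, scored[best]

if __name__ == "__main__":
    # Hypothetical next-token distributions over 3 states.
    strategies = {
        "uniform": [1 / 3, 1 / 3, 1 / 3],
        "unigram": [0.5, 0.3, 0.2],
        "bigram": [0.7, 0.2, 0.1],
    }
    model_probs = [0.68, 0.21, 0.11]
    print(closest_strategy(model_probs, strategies))
```

Averaging such divergences over many sequences and positions at each checkpoint would give curves of the kind the captions describe, and taking the minimum across strategies gives the stage labeling.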

Theorems & Definitions (11)

  • Remark 2.1
  • Proposition 2.2: Transformer Construction
  • proof
  • Lemma 3.1
  • Proposition 3.2
  • Lemma A.1
  • proof
  • Corollary A.2
  • proof
  • Proposition A.3
  • ...and 1 more