
Mechanism and Emergence of Stacked Attention Heads in Multi-Layer Transformers

Tiberiu Musat

TL;DR

The paper defines the Retrieval Problem, a depth-dependent reasoning task that requires $t \,\ge\, \log_3(2D)$ transformer layers to solve. This establishes a theoretical lower bound on network depth that grows logarithmically with the problem size $D$. The paper shows that large language models can solve the problem through prompting, without fine-tuning, across several formulations, and it introduces a minimal formulation for studying the learned circuits. By reverse-engineering attention maps, the work reveals induction-head-like retrieval heads and shows that these heads emerge in a specific sequence under an implicit curriculum. The findings connect depth, curriculum-driven learning, and emergent reasoning abilities, offering insights into how complex multi-head circuits arise and how they might generalize to natural language tasks and safety-critical AI systems.
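As a quick sanity check on the stated bound, the sketch below computes the smallest integer layer count $t$ satisfying $t \ge \log_3(2D)$. Because $t$ is an integer, the condition is equivalent to $3^t \ge 2D$, so the code uses exact integer arithmetic instead of floating-point logarithms. The function name `min_layers` is hypothetical and does not come from the paper.

```python
def min_layers(D: int) -> int:
    """Smallest integer t with t >= log_3(2 * D), i.e. 3**t >= 2 * D.

    Hypothetical helper illustrating the lower bound quoted in the TL;DR;
    it is not code from the paper.
    """
    if D < 1:
        raise ValueError("D must be a positive integer")
    t = 0
    # Increase t until 3**t covers 2*D; integer powers avoid rounding
    # errors that math.log(2 * D, 3) could introduce near exact powers of 3.
    while 3 ** t < 2 * D:
        t += 1
    return t


if __name__ == "__main__":
    for D in (1, 3, 5, 13, 14):
        print(D, min_layers(D))
```

For $D=3$ (the setting of Figure 5) the bound gives $t \ge 2$, and for $D=5$ (the setting of Figure 3) it gives $t \ge 3$. The bound is only a minimum: the transformers reverse-engineered in Figure 5 use 12 layers.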

Abstract

In this paper, I introduce the retrieval problem, a simple yet common reasoning task that can be solved only by transformers with a minimum number of layers, which grows logarithmically with the input size. I empirically show that large language models can solve the task under different prompting formulations without any fine-tuning. To understand how transformers solve the retrieval problem, I train several transformers on a minimal formulation. Successful learning occurs only in the presence of an implicit curriculum. I uncover the learned mechanisms by studying the attention maps of the trained transformers. I also study the training process and find that attention heads always emerge in a specific sequence guided by the implicit curriculum.

Paper Structure

This paper contains 36 sections, 1 equation, and 8 figures.

Figures (8)

  • Figure 1: Illustrative examples of retrieval and conditional retrieval questions.
  • Figure 2: Accuracy of large language models on the retrieval and conditional retrieval problems. Dashed lines indicate the accuracy of random guessing. Full prompts and benchmarking details are provided in the appendix on prompts.
  • Figure 3: Positions that contain shared information before any transformer layers in the case of $D=5$. Top edges denote shared token embeddings. Bottom edges denote shared positional encodings.
  • Figure 4: Final validation loss by number of layers, averaged across multiple runs. Left: IC vs. non-IC formulations. Right: Partial validation loss for each position in the retrieval chains (IC only).
  • Figure 5: Reverse-engineered circuits from three 12-layer transformers trained on the retrieval problem with $D=3$ and IC.
  • ...and 3 more figures