Iteration Head: A Mechanistic Study of Chain-of-Thought

Vivien Cabannes, Charles Arnal, Wassim Bouaziz, Alice Yang, Francois Charton, Julia Kempe

TL;DR

This paper shows how chain-of-thought (CoT) reasoning emerges in transformers in a controlled, interpretable setting, marked by the appearance of a specialized attention mechanism dedicated to iterative reasoning, which the authors call "iteration heads".

Abstract

Chain-of-Thought (CoT) reasoning is known to improve Large Language Models both empirically and in terms of theoretical approximation power. However, our understanding of the inner workings of CoT capabilities, and of the conditions under which they emerge, remains limited. This paper helps fill this gap by demonstrating how CoT reasoning emerges in transformers in a controlled and interpretable setting. In particular, we observe the appearance of a specialized attention mechanism dedicated to iterative reasoning, which we call "iteration heads". We track both the emergence and the precise working of these iteration heads down to the attention level, and measure how well the CoT skills they give rise to transfer between tasks.

Paper Structure (27 sections, 4 equations, 14 figures, 1 table, 1 algorithm)

Figures (14)

  • Figure 1: Arguably, reasoning involves updating an internal state (red) as new information (green) is processed. The diagram, in which each element represents a piece of information, is an abstract depiction of this idea. This observation motivates our use of iterative tasks as a proxy for more general reasoning processes. At first glance, a limitation of transformers is their lack of an internal state, which makes this diagram challenging to implement [lecun2022path].
  • Figure 5: Implementation of an iteration head with a two-layer transformer. Contiguous box: superposition in high-dimensional space. Blue: information brought into the working space by residual connections. Red: information brought in by attention. Green: next-token prediction. The first-layer MLP computes the subtraction $t = (L+t) - (L+1) + 1$ so that the second attention can query $p_t$ using $(p_{L+1}, p_{L+t})$. The second-layer MLP implements $F$ to predict $s_t$ from $(s_{t-1}, x_t)$, with the "end-of-input" mark identified with the initial state $s_0$ of Algorithm 1.
  • Figure 6: Left: attention maps learned for the parity problem when processing a sequence of length $L=29$. Yellow indicates a high attention score. The yellow line in the left map shows that all queries after the EoI token, which sits at position $t = 30$, point to the EoI token. In other words, the first attention implements the "Are you EoI?" query of the iteration-head diagram, while the second implements the "Are you $p_t$?" query. Right: accuracy dynamics for different sequence lengths when learning the parity problem. Short sequences are learned quickly, and the curves show a characteristic staircase behavior (colors follow the tab10 color scheme of Matplotlib [Hunter:2007], with $L\in\{8, 11, 14, 17, \ldots, 32\}$).
  • Figure 7: Test accuracy (red indicates better performance) after training on the polynomial iteration task with $P(X, Y) = XY + 1$ in $\mathbb{F}_{11}$ for 1000 epochs, reported as a function of the embedding dimension ($y$-axis) and the maximum sequence length $L_{\max}$ ($x$-axis). Training used a two-layer transformer with CoT (left), a two-layer transformer without CoT (middle), and a one-layer transformer with CoT (right). This illustrates the usefulness of both CoT and two-layer architectures.
  • Figure 8: Left: attention peakiness score after 1000 epochs of training on the polynomial iteration task with $P(X, Y) = XY + 1$ in $\mathbb{F}_{11}$, as a function of the embedding dimension $d$ and the maximum sequence length $L_{\max}$. Right: examples of attention maps of sub-sampled iteration heads.
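The iterative tasks behind these figures follow the recursion of Algorithm 1, in which the second-layer MLP's function $F$ maps $(s_{t-1}, x_t)$ to $s_t$, and the chain of thought $s_1, \ldots, s_L$ is written out after the end-of-input (EoI) mark. The sketch below generates such sequences for the two tasks named in the captions: parity, and polynomial iteration with $P(X, Y) = XY + 1$ over $\mathbb{F}_{11}$. It is an illustration under stated assumptions, not the authors' data pipeline: the token id `EOI = -1`, the initial state $s_0 = 0$, and the exact layout (inputs, then EoI, then states) are hypothetical choices.

```python
# Hypothetical token id for the end-of-input (EoI) mark; it only needs to
# differ from every input and state value.
EOI = -1


def parity_step(s, x):
    """Parity: the new state is the XOR of the old state and the new bit."""
    return s ^ x


def poly_step(s, x, p=11):
    """Polynomial iteration with P(X, Y) = XY + 1 over F_p (p = 11 in Figures 7-8)."""
    return (x * s + 1) % p


def iterate_states(xs, F, s0):
    """Apply s_t = F(s_{t-1}, x_t) for t = 1..L and return [s_1, ..., s_L]."""
    states, s = [], s0
    for x in xs:
        s = F(s, x)
        states.append(s)
    return states


def cot_sequence(xs, F, s0=0):
    """Inputs x_1..x_L, then the EoI mark, then the chain of thought s_1..s_L.

    The default initial state s0 = 0 is an assumption of this sketch.
    """
    return list(xs) + [EOI] + iterate_states(xs, F, s0)


if __name__ == "__main__":
    print(cot_sequence([1, 0, 1, 1], parity_step))  # [1, 0, 1, 1, -1, 1, 1, 0, 1]
    print(cot_sequence([3, 5, 7], poly_step))       # [3, 5, 7, -1, 1, 6, 10]
```

With this layout, the input $x_t$ sits at position $t$, the EoI mark at position $L+1$, and the token $s_{t-1}$ at position $L+t$, which matches the positions used in the Figure 5 caption.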
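To make the circuit of Figure 5 concrete, the following sketch simulates an idealized iteration head with hard (argmax) attention in place of learned softmax weights. At position $L+t$, the first attention retrieves the EoI position $L+1$, the first-layer MLP computes $t = (L+t) - (L+1) + 1$, the second attention retrieves $x_t$ at position $t$, and the second-layer MLP applies $F$ to $(s_{t-1}, x_t)$. Working with integer positions and token ids instead of embeddings, and the names `EOI`, `hard_attention`, `iteration_head_step`, and `generate_cot`, are hypothetical choices for this illustration; it checks the logic of the circuit, not a trained model.

```python
import numpy as np

EOI = -1  # hypothetical token id for the end-of-input mark


def hard_attention(mask, values):
    """Idealized attention: all weight on the first position where mask is True."""
    assert mask.any(), "no key matches the query"
    scores = np.where(mask, 1.0, -np.inf)
    return values[int(np.argmax(scores))]


def iteration_head_step(seq, F, s0):
    """Predict the token that follows the last token of `seq` (position L + t)."""
    tokens = np.asarray(seq)
    # 1-indexed positions; `seq` is the causal prefix, so no future tokens are visible.
    positions = np.arange(1, len(seq) + 1)
    current = positions[-1]  # = L + t
    # Layer 1 attention, "Are you EoI?": fetch the EoI position L + 1.
    p_eoi = hard_attention(tokens == EOI, positions)
    # Layer 1 MLP: t = (L + t) - (L + 1) + 1.
    t = current - p_eoi + 1
    # Layer 2 attention, "Are you p_t?": fetch the input token x_t.
    x_t = hard_attention(positions == t, tokens)
    # Residual connection carries the current token s_{t-1}; EoI stands for s_0.
    s_prev = s0 if tokens[-1] == EOI else tokens[-1]
    # Layer 2 MLP: apply F to predict s_t.
    return F(int(s_prev), int(x_t))


def generate_cot(xs, F, s0):
    """Run the circuit autoregressively after the EoI mark; return [s_1, ..., s_L]."""
    seq = list(xs) + [EOI]
    for _ in range(len(xs)):
        seq.append(iteration_head_step(seq, F, s0))
    return seq[len(xs) + 1:]
```

Running the circuit autoregressively reproduces the direct recursion: for example, `generate_cot([1, 0, 1, 1], lambda s, x: s ^ x, 0)` returns the parity states `[1, 1, 0, 1]`. Note that neither attention needs to know $L$ in advance; the EoI lookup supplies it, which is why the subtraction in the first-layer MLP suffices.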