
(How) Do Language Models Track State?

Belinda Z. Li, Zifan Carl Guo, Jacob Andreas

TL;DR

This work investigates how transformer language models track evolving world state by framing state tracking as permutation composition on the symmetric group $S_n$. It identifies two robust mechanisms, the associative algorithm (AA) and the parity-associative algorithm (PAA), and shows that LMs trained on permutation tasks converge to one of these two strategies rather than to step-by-step simulation or fully parallel composition. Using probing and activation patching, the authors demonstrate distinct signatures for AA and PAA, including depth-dependent decoding of state and parity, attention patterns, and generalization behavior across sequence lengths. They further show that which mechanism a model learns depends on architecture, initialization, and intermediate training curricula, and that this choice can be steered through targeted pre-training or curriculum design. Overall, the work shows that LMs can learn efficient, interpretable state-tracking mechanisms whose emergence is predictable and controllable, with implications for understanding how real-world LMs model state in language and code tasks.

Abstract

Transformer language models (LMs) exhibit behaviors -- from storytelling to code generation -- that seem to require tracking the unobserved state of an evolving world. How do they do this? We study state tracking in LMs trained or fine-tuned to compose permutations (i.e., to compute the order of a set of objects after a sequence of swaps). Despite the simple algebraic structure of this problem, many other tasks (e.g., simulation of finite automata and evaluation of boolean expressions) can be reduced to permutation composition, making it a natural model for state tracking in general. We show that LMs consistently learn one of two state tracking mechanisms for this task. The first closely resembles the "associative scan" construction used in recent theoretical work by Liu et al. (2023) and Merrill et al. (2024). The second uses an easy-to-compute feature (permutation parity) to partially prune the space of outputs, and then refines this with an associative scan. LMs that learn the former algorithm tend to generalize better and converge faster, and we show how to steer LMs toward one or the other with intermediate training tasks that encourage or suppress the heuristics. Our results demonstrate that transformer LMs, whether pre-trained or fine-tuned, can learn to implement efficient and interpretable state-tracking mechanisms, and the emergence of these mechanisms can be predicted and controlled.
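As a concrete illustration of the task described above, the following minimal sketch treats each swap as a permutation, tracks state as a left-to-right cumulative product, and compares that against a pairwise (tree-shaped) reduction that relies only on associativity. The helper names (`compose`, `swap`, `parity`, `sequential_states`, `tree_product`) and the tuple encoding of permutations are assumptions made for this example; it is not the paper's code and does not model what a transformer computes internally.

```python
from itertools import accumulate


def compose(a, b):
    """Return the single permutation equivalent to applying a, then b.

    A permutation p acts on a state s by new_s[k] = s[p[k]].
    """
    return tuple(a[k] for k in b)


def swap(n, i, j):
    """Permutation on n positions that exchanges positions i and j."""
    p = list(range(n))
    p[i], p[j] = p[j], p[i]
    return tuple(p)


def parity(p):
    """0 for an even permutation, 1 for an odd one (inversion count mod 2)."""
    n = len(p)
    return sum(p[i] > p[j] for i in range(n) for j in range(i + 1, n)) % 2


def sequential_states(actions):
    """State after every prefix of actions: a left-to-right cumulative product."""
    return list(accumulate(actions, compose))


def tree_product(actions):
    """Combine adjacent pairs repeatedly; return (product, number of rounds)."""
    level, rounds = list(actions), 0
    while len(level) > 1:
        nxt = [compose(level[k], level[k + 1]) for k in range(0, len(level) - 1, 2)]
        if len(level) % 2:  # an unpaired last element is carried to the next round
            nxt.append(level[-1])
        level, rounds = nxt, rounds + 1
    return level[0], rounds


# Five swaps on three objects (an arbitrary example sequence in S_3).
actions = [swap(3, 0, 1), swap(3, 1, 2), swap(3, 0, 2), swap(3, 0, 1), swap(3, 1, 2)]
states = sequential_states(actions)
final, rounds = tree_product(actions)

print(states[-1] == final)  # both orders of evaluation agree
print(rounds)               # pairwise rounds needed for 5 actions
print(parity(final), sum(parity(a) for a in actions) % 2)
```

The sequential version needs one step per action, while the pairwise reduction needs a number of rounds that grows only logarithmically with sequence length. Parity is cheap to obtain because it is additive modulo 2 under composition: every swap is odd, so the parity of the final state is simply the number of swaps modulo 2.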

Paper Structure

This paper contains 59 sections, 10 equations, 19 figures, 2 tables, 3 algorithms.

Figures (19)

  • Figure 1: We use permutation word problems as a simple model of state tracking. Actions are permutations, and states are the products of those permutations; the current state can be tracked by taking the cumulative product from left to right (§\ref{sec:preliminary}). We identify several possible algorithms that Transformers may use to solve permutation word problems: sequential, parallel, associative, and parity-associative (§\ref{sec:algorithms}). Above, we depict the "signatures" of each algorithm under two types of interpretability analysis: prefix patching, where we create pairs of prompts differing only in the first token, then substitute all activations except the prefix up to a token at a particular layer, and probing, where we train a linear probe to map from last-token representations across the layers to either the final state or the final state parity (§\ref{sec:methods}). Note: the dotted lines indicate two different probing signatures consistent with this algorithm (see \ref{app:AA_parity} for more details).
  • Figure 2: Activation patching on the residual stream for various Pythia models trained on $S_3$ and $S_5$. Each cell at layer $l$ and token $t$ represents the probability of the correct final state when the entire prefix up to $t$ at layer $l$ is restored. We find signatures matching the AA and PAA algorithms from Figure 1, with both models ignoring exponentially longer prefixes as we traverse down the layers, and PAA models containing intermediate representations that encode some information about the final state, but not its parity.
  • Figure 3: Accuracy of the state probe and the state parity probe across layers on $S_3$ and $S_5$ models sometimes matches the signature for AA and sometimes the signature for PAA. In all models, the state probe accuracy increases roughly exponentially with model depth. We find that in PAA models, the parity of the state is linearly decodable from earlier intermediate layers, while in the AA models shown above, the parity is never linearly encoded in any layer of the model. (In other AA models, the parity can only be linearly decoded at the final layer.)
  • Figure 4: In models that learn PAA on $S_3$, representations of the final product can be geometrically decomposed into two orthogonal directions, corresponding to the parity of the product (represented as the Z-axis in the above graph) and cluster identity of the product (represented by the X-Y plane). Note that the clusters are at 60 degrees to each other, and products of different parities within a cluster are equidistant from each other, with odd-parity products in one plane, and even-parity products in another plane.
  • Figure 5: Generalization curves showing state and parity prediction accuracy as sequence lengths vary. Models are trained on length-100 sequences and asked to generalize to varying lengths of sequences. We plot generalization curves for AA and PAA models on $S_3$ and $S_5$. In each plot, we show the 98% cutoff threshold, the sequence length at which accuracy dips below 98%. In the models that learned PAA, the parity cutoff is larger than the state cutoff, while in models that learned AA, the parity cutoff equals the state cutoff. Generally speaking, models that learned AA generalize better than ones that learned PAA.
  • ...and 14 more figures
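
The 98% cutoff threshold described in the Figure 5 caption (the sequence length at which accuracy dips below 98%) can be sketched as follows. The function name `cutoff_length` and the accuracy values are hypothetical placeholders chosen for illustration; they are not measurements from the paper.

```python
def cutoff_length(lengths, accuracies, threshold=0.98):
    """Return the smallest length whose accuracy is below threshold, else None."""
    for length, acc in sorted(zip(lengths, accuracies)):
        if acc < threshold:
            return length
    return None


# Made-up accuracy curves for one model (NOT results from the paper).
lengths = [50, 100, 150, 200, 250]
state_acc = [1.00, 1.00, 0.99, 0.95, 0.80]
parity_acc = [1.00, 1.00, 1.00, 0.99, 0.90]

state_cutoff = cutoff_length(lengths, state_acc)    # 200
parity_cutoff = cutoff_length(lengths, parity_acc)  # 250

# Per the Figure 5 caption, a parity cutoff larger than the state cutoff is
# the pattern reported for PAA models; equal cutoffs are reported for AA models.
print(state_cutoff, parity_cutoff, parity_cutoff > state_cutoff)
```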

Theorems & Definitions (1)

  • Definition 6.1