(How) Do Language Models Track State?
Belinda Z. Li, Zifan Carl Guo, Jacob Andreas
TL;DR
This work investigates how transformer language models track evolving world state by framing state tracking as permutation composition on the symmetric group $S_n$. It identifies two robust mechanisms, the associative algorithm (AA) and the parity-associative algorithm (PAA), and shows that LMs trained on permutation tasks converge to one of these two strategies rather than to step-by-step simulation or fully parallel composition. Using probing and activation patching, the authors demonstrate distinct signatures for AA and PAA, including depth-dependent decoding of state and parity, attention patterns, and generalization behavior across sequence lengths. They further show that which mechanism a model learns depends on architecture and initialization, and that this outcome can be steered through targeted pre-training or intermediate curricula. Overall, the work shows that LMs can learn efficient, interpretable state-tracking mechanisms, and that the emergence of these mechanisms is predictable and controllable, with implications for understanding how real-world LMs model state in language and code tasks.
Abstract
Transformer language models (LMs) exhibit behaviors -- from storytelling to code generation -- that seem to require tracking the unobserved state of an evolving world. How do they do this? We study state tracking in LMs trained or fine-tuned to compose permutations (i.e., to compute the order of a set of objects after a sequence of swaps). Despite the simple algebraic structure of this problem, many other tasks (e.g., simulation of finite automata and evaluation of boolean expressions) can be reduced to permutation composition, making it a natural model for state tracking in general. We show that LMs consistently learn one of two state tracking mechanisms for this task. The first closely resembles the "associative scan" construction used in recent theoretical work by Liu et al. (2023) and Merrill et al. (2024). The second uses an easy-to-compute feature (permutation parity) to partially prune the space of outputs, and then refines this with an associative scan. LMs that learn the former algorithm tend to generalize better and converge faster, and we show how to steer LMs toward one or the other with intermediate training tasks that encourage or suppress the heuristics. Our results demonstrate that transformer LMs, whether pre-trained or fine-tuned, can learn to implement efficient and interpretable state-tracking mechanisms, and the emergence of these mechanisms can be predicted and controlled.
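To make the task concrete, the following is a minimal Python sketch of permutation composition over a sequence of swaps, together with the two ingredients the abstract mentions: a pairwise (tree-structured) reduction in the spirit of an associative scan, and the permutation parity feature. It illustrates the task itself, not the mechanism recovered from any trained model. All names here (`swap_perm`, `compose`, `sequential_compose`, `tree_compose`, `parity`) and the example swaps are assumptions introduced for illustration.

```python
from functools import reduce


def identity(n):
    """Identity permutation on n positions."""
    return tuple(range(n))


def swap_perm(i, j, n):
    """Permutation that exchanges positions i and j and fixes the rest."""
    p = list(range(n))
    p[i], p[j] = p[j], p[i]
    return tuple(p)


def compose(p, q):
    """Permutation equivalent to applying p first, then q.

    Convention (an assumption of this sketch): applying p to a state s
    yields s' with s'[k] = s[p[k]], so the combined map is k -> p[q[k]].
    """
    return tuple(p[q[k]] for k in range(len(p)))


def sequential_compose(perms, n):
    """Step-by-step simulation: fold the sequence left to right."""
    return reduce(compose, perms, identity(n))


def tree_compose(perms, n):
    """Combine adjacent pairs layer by layer (logarithmic depth).

    This relies only on associativity of `compose`; the left-to-right
    order of the permutations is preserved within each layer.
    """
    layer = list(perms)
    if not layer:
        return identity(n)
    while len(layer) > 1:
        nxt = [compose(layer[k], layer[k + 1])
               for k in range(0, len(layer) - 1, 2)]
        if len(layer) % 2:  # carry an unpaired last element forward
            nxt.append(layer[-1])
        layer = nxt
    return layer[0]


def parity(p):
    """0 for an even permutation, 1 for an odd one (inversion count mod 2)."""
    n = len(p)
    inversions = sum(1 for a in range(n)
                     for b in range(a + 1, n) if p[a] > p[b])
    return inversions % 2


# Hypothetical example: five swaps over five objects.
n = 5
swaps = [(0, 1), (1, 2), (0, 2), (3, 4), (1, 3)]
perms = [swap_perm(i, j, n) for i, j in swaps]
final_state = tree_compose(perms, n)
```

Because each swap of two distinct positions flips parity, the parity of the final permutation equals the number of swaps modulo 2. This is why parity is cheap to compute and can rule out half of the candidate outputs before any finer composition is done.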
