Transformers Use Causal World Models in Maze-Solving Tasks

Alex F. Spies, William Edwards, Michael I. Ivanitskiy, Adrians Skapars, Tilman Räuker, Katsumi Inoue, Alessandra Russo, Murray Shanahan

TL;DR

Transformers solving maze tasks develop structured world models (WMs) that reflect the environment's connectivity. The authors combine early-layer connectivity-attention analysis with sparse autoencoders (SAEs) trained on the residual stream to identify causal, disentangled WM features and validate them through targeted interventions. They show an asymmetry in interventions, with activating features being more effective than removing them, and reveal a compositional code in the latent WM representations that is influenced by positional embeddings. The results demonstrate that SAEs reveal WM features missed by linear probes, enabling steerability of sequential planners and offering new insights for interpretability and safety in AI systems.

Abstract

Recent studies in interpretability have explored the inner workings of transformer models trained on tasks across various domains, often discovering that these networks naturally develop highly structured representations. When such representations comprehensively reflect the task domain's structure, they are commonly referred to as "World Models" (WMs). In this work, we identify WMs in transformers trained on maze-solving tasks. By using Sparse Autoencoders (SAEs) and analyzing attention patterns, we examine the construction of WMs and demonstrate consistency between SAE feature-based and circuit-based analyses. By subsequently intervening on isolated features to confirm their causal role, we find that it is easier to activate features than to suppress them. Furthermore, we find that models can reason about mazes involving more simultaneously active features than they encountered during training; however, when these same mazes (with greater numbers of connections) are provided to models via input tokens instead, the models fail. Finally, we demonstrate that positional encoding schemes appear to influence how World Models are structured within the model's residual stream.
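
As a rough illustration of what intervening on an isolated SAE feature can look like, the sketch below encodes a residual-stream vector with a toy ReLU sparse autoencoder, then moves one feature's activation to a chosen value by shifting the vector along that feature's decoder direction. Everything here is an assumption made for illustration (the function names `sae_encode` and `intervene`, the random weights, the sizes, and the choice to leave the SAE reconstruction error untouched); it is not the authors' implementation.

```python
import numpy as np

def sae_encode(x, W_enc, b_enc):
    """ReLU encoder: feature activations for a residual-stream vector x."""
    return np.maximum(0.0, x @ W_enc + b_enc)

def intervene(x, W_enc, b_enc, W_dec, feature, value):
    """Set one SAE feature to `value` by editing x along its decoder direction.

    Only the change in the chosen feature's contribution is added, so any
    SAE reconstruction error stays in the residual vector unchanged.
    """
    acts = sae_encode(x, W_enc, b_enc)
    delta = value - acts[feature]
    return x + delta * W_dec[feature]

# Hypothetical sizes and random weights, standing in for a trained SAE.
rng = np.random.default_rng(0)
d_model, n_features = 16, 64
W_enc = rng.normal(size=(d_model, n_features)) / np.sqrt(d_model)
b_enc = np.zeros(n_features)
W_dec = rng.normal(size=(n_features, d_model))
W_dec /= np.linalg.norm(W_dec, axis=1, keepdims=True)  # unit-norm decoder rows

x = rng.normal(size=d_model)
x_on = intervene(x, W_enc, b_enc, W_dec, feature=3, value=5.0)   # activate
x_off = intervene(x, W_enc, b_enc, W_dec, feature=3, value=0.0)  # suppress
```

In this toy setting the two edits differ only by a multiple of the decoder direction. The asymmetry reported in the abstract concerns how the downstream model responds to such edits, which this sketch does not model.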

Paper Structure

This paper contains 19 sections, 1 equation, 26 figures, 3 tables.

Figures (26)

  • Figure 1: Overview of our methodology for discovering and validating world models in transformer-based maze solvers. (A) We analyze attention patterns in early layers, finding heads that consolidate maze connectivity information at semicolon tokens. (B) We train sparse autoencoders on the residual stream immediately following the first block, identifying interpretable features that encode maze connectivity. (C) We demonstrate the causal role of the world models in our transformers by comparing the features extracted through both methods and validating them through causal interventions.
  • Figure 2: Tokenization scheme and visualization of a shortest-path maze task generated using the maze-dataset library (Ivanitskiy et al., 2023).
  • Figure 3: Attention values for heads L0H3, L0H5, and L0H7 in Stan. We use a rather nonstandard representation, looking only at a fixed window into the past of which tokens are attended to by semicolon tokens. Every $4$th position, up to $140$, is shown along the $x$-axis. Color shows attention to positions 1, 3, 5, and 7 earlier in the context (shown along the $y$-axis), for an example 6x6 maze input. This sort of pattern is typical across all inputs examined. Up until context position 100, the heads are attending 1 and 3 positions back; after this the pattern shifts to 5 and 7 back. Note the complementary attention patterns of L0H3 and L0H7. Closer examination shows that L0H3 prefers to direct its attention to 'even-parity' maze cells, with L0H7 preferring 'odd-parity' cells. L0H5 more frequently splits its attention between 1 and 3 back, but sometimes 'fills in' for L0H7. The origins of this pattern are explored further in appendix \ref{app:stan-qk-circuit}; note also the similarities to Figure 4. The other five heads in L0 show no similar pattern. Full patterns are shown in \ref{fig:s3:ST_attn_full}.
  • Figure 4: Magnitudes of vectors resulting from applying the $W_{OV}$ matrices of heads L0H3, L0H5 and L0H7 of Stan to maze-cell token embeddings, projected onto the maze grid. The pattern here mirrors the way that the heads divide their attention between the 1-back and 3-back context positions (exemplified in Figure 3), with L0H3 focused on 'even-parity' cells, and L0H7 and L0H5 focused primarily on 'odd-parity' cells. This pattern also recurs in the overlaps between query and key vectors of token embeddings, explored in detail in Appendix \ref{app:stan-qk-circuit}.
  • Figure 5: Magnitudes of vectors resulting from applying the $W_{OV}$ matrices of layer-0 heads of Terry to maze-cell token embeddings, projected onto the maze grid. The pattern here is much less striking than that for Stan (shown in Figure 4), although it does suggest that the heads specialise in even/odd-parity cells in localised regions of the maze.
  • ...and 21 more figures
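
The $W_{OV}$ magnitude maps described in the captions for Figures 4 and 5 can be sketched as follows: apply a head's combined value-output map to each maze-cell token embedding, take the norm of the result, and lay the norms out on the maze grid. This is a minimal sketch under stated assumptions, not the authors' code: the function names, the row-vector convention (`x @ W_V @ W_O`), the row-major cell ordering, and the definition of cell parity as `(row + col) % 2` are all hypothetical, and the weights are random placeholders rather than trained parameters.

```python
import numpy as np

def ov_magnitudes(E_cells, W_V, W_O, grid_n):
    """Norm of each cell embedding after one head's OV map, on the maze grid.

    E_cells: (grid_n * grid_n, d_model) embeddings, row-major over (row, col)
    W_V:     (d_model, d_head) value projection of the head
    W_O:     (d_head, d_model) output projection of the head
    """
    W_OV = W_V @ W_O                      # combined (d_model, d_model) map
    out = E_cells @ W_OV                  # one output vector per maze cell
    return np.linalg.norm(out, axis=1).reshape(grid_n, grid_n)

def parity_means(mags):
    """Mean magnitude over even-parity and odd-parity cells (assumed parity rule)."""
    rows, cols = np.indices(mags.shape)
    even = (rows + cols) % 2 == 0
    return mags[even].mean(), mags[~even].mean()

# Placeholder data: a 6x6 grid with hypothetical model dimensions.
rng = np.random.default_rng(0)
grid_n, d_model, d_head = 6, 32, 8
E_cells = rng.normal(size=(grid_n * grid_n, d_model))
W_V = rng.normal(size=(d_model, d_head))
W_O = rng.normal(size=(d_head, d_model))

mags = ov_magnitudes(E_cells, W_V, W_O, grid_n)
even_mean, odd_mean = parity_means(mags)
```

With trained weights, a large gap between the two means for a given head would be one simple way to quantify the even/odd-parity specialisation that the captions describe qualitatively.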