Grid-World Representations in Transformers Reflect Predictive Geometry

Sasha Brenner, Thomas R. Knösche, Nico Scherf

Abstract

Next-token predictors often appear to develop internal representations of the latent world and its rules. The probabilistic nature of these models suggests a deep connection between the structure of the world and the geometry of probability distributions. To understand this link more precisely, we use a minimal stochastic process as a controlled setting: constrained random walks on a two-dimensional lattice that must reach a fixed endpoint after a predetermined number of steps. Optimal prediction of this process depends solely on a sufficient vector determined by the walker's position relative to the target and the remaining time horizon; in other words, the probability distributions are parametrized by the world's geometry. We train decoder-only transformers on prefixes sampled from the exact distribution of these walks and compare their hidden activations to the analytically derived sufficient vectors. Across models and layers, the learned representations align strongly with the ground-truth predictive vectors and are often low-dimensional. This provides a concrete example in which world-model-like representations can be traced directly back to the predictive geometry of the data itself. Although the analysis is demonstrated in a simplified toy system, it suggests that geometric representations supporting optimal prediction may provide a useful lens for studying how neural networks internalize grammatical and other structural constraints.
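Exact path counting gives a quick way to check that optimal prediction depends only on the walker's position relative to the target and on the remaining horizon. The Python sketch below rests on three assumptions: moves are the four nearest-neighbour steps on the square lattice, the walk starts at the origin, and the "exact distribution" is uniform over all valid paths, that is, a simple random walk conditioned on its endpoint. The names `STEPS`, `count_paths`, `next_step_probs`, and `sample_walk` are hypothetical and do not come from the paper. In rotated coordinates u = x + y and v = x - y, each move changes u and v independently by ±1. The number of n-step paths with displacement (dx, dy) is therefore C(n, (n + dx + dy)/2) · C(n, (n + dx - dy)/2), and this count is zero when parity or range rules the displacement out. With displacement d to the target and t steps left, a move s then has probability N_{t-1}(d - s) / N_t(d). The sketch only illustrates this dependence on (dx, dy, t). It does not reproduce the paper's analytically derived sufficient vectors.

```python
import random
from math import comb

# Hypothetical move alphabet: the four nearest-neighbour steps on the square lattice.
STEPS = {"R": (1, 0), "L": (-1, 0), "U": (0, 1), "D": (0, -1)}


def count_paths(n, dx, dy):
    """Number of n-step lattice paths whose total displacement is (dx, dy).

    In rotated coordinates u = x + y, v = x - y, every move changes u and v
    independently by +/-1, so the count factorizes into two binomials.
    """
    du, dv = dx + dy, dx - dy
    if n < 0 or abs(du) > n or abs(dv) > n or (n + du) % 2:
        return 0
    return comb(n, (n + du) // 2) * comb(n, (n + dv) // 2)


def next_step_probs(dx, dy, t):
    """Exact next-move distribution given displacement (dx, dy) to the target
    and t remaining steps; only (dx, dy, t) enters, not the full prefix."""
    total = count_paths(t, dx, dy)
    if total == 0:
        raise ValueError("target is unreachable in exactly t steps")
    return {
        name: count_paths(t - 1, dx - sx, dy - sy) / total
        for name, (sx, sy) in STEPS.items()
    }


def sample_walk(T, target=(0, 0), rng=None):
    """Sample a T-step walk from the origin that ends exactly at `target`.

    target == (0, 0) returns to the start (our reading of a "looper");
    any other reachable target gives a shifted endpoint.
    """
    rng = rng or random.Random()
    x, y = 0, 0
    path = []
    for t in range(T, 0, -1):
        probs = next_step_probs(target[0] - x, target[1] - y, t)
        # Drop impossible moves so a zero-probability step can never be drawn.
        moves = [m for m, p in probs.items() if p > 0]
        step = rng.choices(moves, weights=[probs[m] for m in moves])[0]
        x, y = x + STEPS[step][0], y + STEPS[step][1]
        path.append(step)
    return path
```

For example, `next_step_probs(1, 1, 4)` returns `{'R': 0.375, 'L': 0.125, 'U': 0.375, 'D': 0.125}`. With the target one step right and one step up and four steps left, moves toward the target are three times as likely as moves away from it. Each conditional equals the fraction of valid completions that pass through that move. Drawing moves one at a time from these conditionals therefore samples complete walks uniformly over all valid paths.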

Paper Structure (43 sections, 34 equations, 3 figures, 1 table)

Figures (3)

  • Figure 1: Transformer layer representations closely resemble the ground-truth predictive vectors of Equation \ref{eq:next_step_probs_lin} (computed as in Appendix \ref{app:suff_vec}). The left column shows the ground-truth principal components (PCs), and the right column shows the PCs of the final LayerNorm representation after Procrustes alignment to the ground-truth vectors. Each point represents a distinct sequence of moves on the square lattice, colored by the number of steps taken (path length). Although these representations could be interpreted as grid-world models, they simply correspond to minimal and sufficient vectors for prediction.
  • Figure 2: Linear similarity metrics between layer activations and ground-truth predictive vectors. (A) and (D): Both $R^2$ and $\mathrm{lCKA}$ are high and tend to increase with layer depth, with LayerNorms generally exhibiting stronger similarity. Moreover, for a given endpoint, walkers with longer time horizons $T$ have lower similarity scores, while for a given time horizon, loopers score higher than shifted-endpoint walkers. In these plots, solid lines correspond to loopers and dashed lines to shifted-endpoint walkers. (B-C): The $R^2$ value for the $T=20$ looper is shown as a function of the training step in (B) and compared with the excess validation loss at the corresponding training step in (C). (E-F): The $\mathrm{lCKA}$ value for the same $T=20$ looper is likewise shown as a function of the training step (E) and compared with the loss (F).
  • Figure 3: For the $T=20$ looper, the number of principal components required to explain 99% of the variance converges to 2 for the final LayerNorm's activations, but it rebounds for all previous hidden layers after around $10^4$ training steps.
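The figure captions name four measures: $R^2$, $\mathrm{lCKA}$, Procrustes alignment, and the number of PCs that explain 99% of the variance. The NumPy sketch below shows one way to compute each of them. It is an illustration, not the authors' analysis code. Linear CKA follows the standard centered formulation of Kornblith et al. (2019). Several details are our own assumptions: the regression runs from activations to ground truth, $R^2$ is pooled across output dimensions, Procrustes alignment is done on equal-width matrices, and the helper names are hypothetical. Rows of `x` and `y` are assumed to index the same set of prefixes.

```python
import numpy as np


def _center(a):
    """Subtract column means so every feature has zero mean over prefixes."""
    a = np.asarray(a, dtype=float)
    return a - a.mean(axis=0, keepdims=True)


def linear_cka(x, y):
    """Linear CKA between x (n, d1) and y (n, d2); rows index the same n prefixes."""
    x, y = _center(x), _center(y)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    return cross / (np.linalg.norm(x.T @ x, "fro") * np.linalg.norm(y.T @ y, "fro"))


def linear_r2(x, y):
    """R^2 of the least-squares affine map x -> y, pooled over all columns of y."""
    x, y = _center(x), _center(y)  # centering both sides absorbs the intercept
    coef, *_ = np.linalg.lstsq(x, y, rcond=None)
    resid = y - x @ coef
    return 1.0 - np.sum(resid**2) / np.sum(y**2)


def procrustes_align(a, b):
    """Return a @ R, with R the orthogonal matrix minimizing ||a R - b||_F.

    a and b must have the same shape (e.g. top-k PCs of activations vs.
    k ground-truth coordinates); both are centered first. R may include
    a reflection, since only orthogonality is enforced.
    """
    a, b = _center(a), _center(b)
    u, _, vt = np.linalg.svd(a.T @ b)
    return a @ (u @ vt)


def n_pcs_for_variance(x, threshold=0.99):
    """Smallest number of principal components explaining `threshold` of the variance."""
    s = np.linalg.svd(_center(x), compute_uv=False)
    explained = np.cumsum(s**2) / np.sum(s**2)
    return int(min(np.searchsorted(explained, threshold) + 1, len(s)))
```

Suppose `acts` holds final-LayerNorm activations and `gt` holds the ground-truth vectors for the same prefixes. Under the assumptions above, `linear_cka(acts, gt)` and `linear_r2(acts, gt)` would correspond to the two scores in Figure 2, and `n_pcs_for_variance(acts)` to the quantity tracked in Figure 3. Both scores equal 1 whenever `acts` is a rotated, isotropically rescaled, and shifted copy of `gt`.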