Transformers represent belief state geometry in their residual stream

Adam S. Shai, Sarah E. Marzen, Lucas Teixeira, Alexander Gietelink Oldenziel, Paul M. Riechers

TL;DR

The paper argues that transformer residual streams encode the meta-dynamics of belief updating over the hidden states of the data-generating process, formalized via the mixed-state presentation (MSP) of a hidden Markov model. Belief states are points in a probability simplex, and the authors show that they can be recovered from residual stream activations by a linear projection, even when the belief state geometry is fractal. In controlled experiments on data generated by known hidden Markov models (e.g., Mess3 and RRXOR), the belief state geometry is linearly represented either in the final residual stream or distributed across the residual streams of multiple layers, and the represented beliefs carry information about the entire future, beyond the next token. These findings connect the data-generating process to activation geometry, with implications for interpretability, minimal architectural requirements, and understanding how large language models implement predictive inference.

Abstract

What computational structure are we building into large language models when we train them on next-token prediction? Here, we present evidence that this structure is given by the meta-dynamics of belief updating over hidden states of the data-generating process. Leveraging the theory of optimal prediction, we anticipate and then find that belief states are linearly represented in the residual stream of transformers, even in cases where the predicted belief state geometry has highly nontrivial fractal structure. We investigate cases where the belief state geometry is represented in the final residual stream or distributed across the residual streams of multiple layers, providing a framework to explain these observations. Furthermore, we demonstrate that the inferred belief states contain information about the entire future, beyond the local next-token prediction that the transformers are explicitly trained on. Our work provides a general framework connecting the structure of training data to the geometric structure of activations inside transformers.
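The belief updating mentioned in the abstract can be made concrete with a small sketch. The code below is an illustration, not the authors' implementation: it assumes the zero-one-random (Z1R) process described in the Figure 3 caption, with hidden states ordered as (S0, S1, SR) and the random bit taken to be fair (0.5 each), which is an assumption. The names `T`, `update_belief`, and `belief_after` are hypothetical. Each observed token maps the current belief, a row vector of probabilities over hidden states, through that token's transition matrix and renormalizes the result.

```python
import numpy as np

# Hidden-state order: (S0, S1, SR).
# T[x][i, j] = P(emit token x and move to state j | currently in state i).
# Assumption for illustration: the random bit emitted from SR is fair.
T = {
    0: np.array([[0.0, 1.0, 0.0],    # S0 emits 0, moves to S1
                 [0.0, 0.0, 0.0],    # S1 never emits 0
                 [0.5, 0.0, 0.0]]),  # SR emits 0 half the time, moves to S0
    1: np.array([[0.0, 0.0, 0.0],    # S0 never emits 1
                 [0.0, 0.0, 1.0],    # S1 emits 1, moves to SR
                 [0.5, 0.0, 0.0]]),  # SR emits 1 half the time, moves to S0
}


def update_belief(belief, token):
    """Bayesian update of a belief over hidden states after one token."""
    unnormalized = belief @ T[token]
    prob = unnormalized.sum()  # probability of the token under this belief
    if prob == 0.0:
        raise ValueError("token is impossible under the current belief")
    return unnormalized / prob


def belief_after(tokens, prior):
    """Belief state reached from `prior` after observing `tokens` in order."""
    belief = np.asarray(prior, dtype=float)
    for token in tokens:
        belief = update_belief(belief, token)
    return belief


prior = np.full(3, 1.0 / 3.0)  # uniform is stationary for this process
print(belief_after([0], prior))
print(belief_after([0, 0], prior))
```

Starting from the uniform prior, a single 0 gives the belief (1/3, 2/3, 0), and the string 00 gives (0, 1, 0): two zeros in a row can only occur if the first was the random bit, so the predictor has synchronized to S1. The set of all beliefs reachable in this way is the belief state geometry that the captions below refer to.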

Paper Structure (30 sections, 6 equations, 9 figures)

Figures (9)

  • Figure 1: (Top) Given a hidden data-generating structure, our framework predicts a unique belief state geometry in a probability simplex. These geometries often have highly nontrivial fractal structure, as shown in this example. (Bottom) Our main experimental result is that the fractal geometry of optimal beliefs is linearly embedded in the residual stream and emerges over the course of training.
  • Figure 2: An illustration of a hidden Markov model (HMM) and its components. The left side shows the HMM with states $\texttt{S}_\texttt{0}$, $\texttt{S}_\texttt{1}$, and $\texttt{S}_\texttt{R}$, and their respective transition probabilities. The right side displays the transition matrices $T^{(\texttt{0})}$ and $T^{(\texttt{1})}$ corresponding to token emissions $\texttt{0}$ and $\texttt{1}$. Example training data is provided at the bottom, demonstrating a sequence generated by the HMM.
  • Figure 3: (A) An example generative structure called the zero-one-random process (Z1R), since it generates data of the form ...01R01R01R... where R is a random bit. (B) The generative structure implies a unique meta-dynamics over belief states, describing how a predictor synchronizes to the hidden state of the world as it observes more context. This predictive structure is called the mixed-state presentation (MSP). We label belief states with $\boldsymbol{\eta}_w$ where $w$ is the shortest string of emissions that leads to that belief state. (C) Belief states are distributions over generator states and can be embedded in a probability simplex. To read off this distribution for a given belief state, $\boldsymbol{\eta}$, one measures the perpendicular distance from the point to each edge of the simplex, shown as blue, green, and red lines in this example. These distances directly give the probabilities for each state. Thus, the vertices represent states of certainty over one generator state, since the perpendicular distance is nonzero to only one of the edges of the simplex. (D) Plotting the belief state distributions in the probability simplex gives the belief state geometry.
  • Figure 4: To test whether transformers represent belief state geometries in their residual streams, we record (A) residual stream activations at all context window positions over all inputs. (B) These activations live in a high-dimensional space. (C) Each input has a ground-truth optimal belief state, which is a probability distribution over states of the data-generating process. In this way, we can label, or color, each activation by the ground-truth belief associated with the input. (D) Using linear regression, we then find a linear subspace of the activation space that best preserves the belief state geometry of the simplex.
  • Figure 5: The residual stream of trained transformers linearly represents the belief state geometry of the mixed-state presentation. (A) The Mess3 Process has 3 hidden states and generates sequences in a token vocabulary of $\{\texttt{A}, \texttt{B},\texttt{C}\}$. (B) The ground-truth belief state geometry of the Mess3 Process has intricate fractal structure. Each point in this plot is a belief state, that is, a probability distribution over the hidden states of the Mess3 Process. Points are colored by using the three probabilities of each belief state as RGB values. (C) We find that a linear projection of the final residual stream activations contains a representation of the ground-truth belief geometry. Points are colored according to the ground-truth belief states.
  • ...and 4 more figures
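
The regression step described in the Figure 4 caption can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' code: the "activations" here are synthetic stand-ins, built by linearly embedding random belief states into a 64-dimensional space and adding small noise, rather than recordings from a trained transformer. The names `fit_linear_probe` and `apply_probe`, the dimensions, and the noise level are all hypothetical. The probe is an ordinary least-squares affine map from activations to ground-truth belief states.

```python
import numpy as np


def _with_bias(activations):
    """Append a constant column so the fitted map is affine."""
    ones = np.ones((activations.shape[0], 1))
    return np.hstack([activations, ones])


def fit_linear_probe(activations, beliefs):
    """Least-squares map from activations (n, d) to beliefs (n, k)."""
    weights, *_ = np.linalg.lstsq(_with_bias(activations), beliefs, rcond=None)
    return weights  # shape (d + 1, k)


def apply_probe(activations, weights):
    """Project activations into belief (simplex) coordinates."""
    return _with_bias(activations) @ weights


# Synthetic stand-in data (assumption): beliefs over 3 hidden states,
# linearly embedded in a 64-dimensional "residual stream" plus noise.
rng = np.random.default_rng(0)
n, d, k = 2000, 64, 3
beliefs = rng.dirichlet(np.ones(k), size=n)
embedding = rng.normal(size=(k, d))
activations = beliefs @ embedding + 0.01 * rng.normal(size=(n, d))

# Fit on one split and evaluate on held-out points.
weights = fit_linear_probe(activations[:1500], beliefs[:1500])
predicted = apply_probe(activations[1500:], weights)
mse = np.mean((predicted - beliefs[1500:]) ** 2)
print(f"held-out mean squared error: {mse:.2e}")
```

With real data, `activations` would hold the recorded residual stream vectors and `beliefs` the ground-truth belief state for each input prefix. A low held-out error would then indicate that the belief state geometry is linearly recoverable; on the synthetic data above the error is small by construction, so it demonstrates only the mechanics of the fit.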