Constrained belief updates explain geometric structures in transformer representations

Mateusz Piotrowski, Paul M. Riechers, Daniel Filan, Adam S. Shai

TL;DR

The paper investigates which computational structures emerge in transformers trained on next-token prediction and argues that these models implement Bayesian belief updating constrained by their architecture. Using the Mess3 hidden Markov model (HMM) as a tractable testbed, it shows that attention-based updates map onto the geometry of the belief simplex, that OV vectors and token embeddings match theoretical predictions, and that a spectral analysis predicts when single-head versus multi-head attention is required. Across experiments, first-layer attention performs constrained belief updates, while deeper layers progressively transform the representations toward full Bayesian beliefs. The results offer a principled interpretability lens on how architectural constraints shape inference in transformers, with insights that may generalize to larger language models.
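
As a concrete reference point for "full Bayesian beliefs", the sketch below builds a Mess3-style HMM and runs exact Bayesian belief updating on a token sequence. The parametrization is an assumption taken from earlier belief-geometry work rather than from this summary: the hidden state stays put with probability 1 - 2x and hops to each other state with probability x, and the emitted token matches the new state with probability alpha (otherwise (1 - alpha)/2 each). It is at least consistent with the eigenvalue 1 - 3x quoted in the figure captions. All function names are hypothetical.

```python
import numpy as np


def mess3_matrices(x: float, alpha: float) -> np.ndarray:
    """Labeled transition matrices T[a, i, j] = P(emit token a, move to state j | state i).

    Assumed Mess3 parametrization: stay with probability 1 - 2x, hop to each
    other state with probability x; token a is emitted with probability alpha
    if the new state is a, and (1 - alpha) / 2 otherwise.
    """
    beta = (1.0 - alpha) / 2.0
    T = np.empty((3, 3, 3))
    for a in range(3):
        for i in range(3):
            for j in range(3):
                move = 1.0 - 2.0 * x if i == j else x
                emit = alpha if j == a else beta
                T[a, i, j] = move * emit
    return T


def update_belief(eta: np.ndarray, T: np.ndarray, token: int) -> np.ndarray:
    """One exact Bayesian update of the belief over hidden states after seeing `token`."""
    unnormalized = eta @ T[token]
    return unnormalized / unnormalized.sum()


def belief_trajectory(tokens, T: np.ndarray) -> np.ndarray:
    """Beliefs after each prefix, starting from the stationary (uniform) distribution."""
    eta = np.full(3, 1.0 / 3.0)
    beliefs = [eta]
    for token in tokens:
        eta = update_belief(eta, T, token)
        beliefs.append(eta)
    return np.array(beliefs)


# Hyperparameters and example sequence (01120) borrowed from the Figure 2 caption.
T = mess3_matrices(x=0.15, alpha=0.6)
beliefs = belief_trajectory([0, 1, 1, 2, 0], T)
print(beliefs.round(3))
```

Each row of `beliefs` is one point in the 2-simplex; collecting such points over many sequences is how a ground-truth belief geometry like the one in Figure 1(B) would be traced out. The update is inherently recurrent (each belief depends on the previous one), which is the property that parallel attention cannot reproduce exactly.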

Abstract

What computational structures emerge in transformers trained on next-token prediction? In this work, we provide evidence that transformers implement constrained Bayesian belief updating: a parallelized version of partial Bayesian inference shaped by architectural constraints. We integrate the model-agnostic theory of optimal prediction with mechanistic interpretability to analyze transformers trained on a tractable family of hidden Markov models that generate rich geometric patterns in neural activations. Our primary analysis focuses on single-layer transformers, revealing how the first attention layer implements these constrained updates, with extensions to multi-layer architectures demonstrating how subsequent layers refine these representations. We find that attention carries out an algorithm with a natural interpretation in the probability simplex and creates representations with distinctive geometric structure. We show how both the algorithmic behavior and the underlying geometry of these representations can be theoretically predicted in detail (including the attention pattern, OV vectors, and embedding vectors) by modifying the equations for optimal future-token prediction to account for the architectural constraints of attention. Our approach provides a principled lens on how architectural constraints shape the implementation of optimal prediction, revealing why transformers develop specific intermediate geometric structures.
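
The spectral quantity behind the predicted attention pattern can be checked directly. Under the assumed Mess3 parametrization (stay with probability 1 - 2x, hop with probability x), the token-summed transition matrix is T = ζI + (1 - ζ)P with ζ = 1 - 3x and P the projector onto the uniform stationary distribution, so T^n = P + ζ^n (I - P): whatever a token n steps back contributes beyond the stationary part shrinks by ζ^n. This is a minimal sketch of that identity only; it does not reproduce the paper's constrained-update equation, which this summary does not state, and the function names are hypothetical.

```python
import numpy as np


def mess3_state_transition(x: float) -> np.ndarray:
    """Hidden-state transition matrix (tokens summed out) under the assumed Mess3 parametrization."""
    y = 1.0 - 2.0 * x
    return np.array([[y, x, x], [x, y, x], [x, x, y]])


def spectral_prediction(x: float, n: int) -> np.ndarray:
    """Predicted T^n = P + zeta^n (I - P), with P projecting onto the uniform stationary state."""
    zeta = 1.0 - 3.0 * x
    P = np.full((3, 3), 1.0 / 3.0)
    return P + zeta**n * (np.eye(3) - P)


for x in (0.05, 0.15, 0.5):
    T = mess3_state_transition(x)
    for n in range(6):
        exact = np.linalg.matrix_power(T, n)
        assert np.allclose(exact, spectral_prediction(x, n))
    print(f"x={x}: influence of a token n steps back scales as ({1 - 3 * x:+.2f})^n")
```

For x > 1/3 the eigenvalue ζ is negative, so ζ^n alternates in sign; this is the regime in which the figure captions describe two attention heads combining into an oscillatory pattern.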

Paper Structure

This paper contains 32 sections, 23 equations, 12 figures, 1 table.

Figures (12)

  • Figure 1: Transformers' internal representations exhibit complex geometric structure matching the belief-state geometry. (A) The Mess3 HMM; vertices represent hidden states and their emission distributions. (B) Ground-truth belief-state geometry of Mess3. Each point is a belief state, a probability distribution over the HMM's hidden states induced by Bayesian updating on a sequence of observed emissions; proximity to each vertex of the simplex reflects the probability of the corresponding hidden state. (C) Schematic of a single-layer transformer, with intermediate activations taken after attention and final activations after the subsequent MLP. (D) The PCA projection of the model's final residual stream (left), before the unembedding, reveals a geometric representation that closely matches the belief geometry in (B), whereas the PCA projection of the intermediate residual stream (right), after attention but before the MLP, exhibits an intricate but different structure. In (B) and (D), points are colored by the ground-truth belief state associated with the token sequence that induces them, using the three hidden-state probabilities as RGB values.
  • Figure 2: Intermediate representation construction by attention. A transformer trained on Mess3 with $x=0.15$ and $\alpha=0.6$ builds its intermediate representations through a specific attention mechanism. (A) The OV vectors (arrows) form three distinct clusters, one per token, positioned at the vertices of a triangle, while the token embeddings (circles) cluster near the origin. (B) Our theoretical predictions for the OV vectors (shown for all (position, token) pairs) and embeddings (for positions $>2$) align closely with those found in the trained transformer. (C) Attention patterns are determined primarily by the positional distance between destination and source tokens, following an exponential decay described by $(1-3x)^{|n-1|}$, and are largely independent of the specific token sequence. (C, inset) The theoretical (Eq. \ref{eq:Attention_dest_relation}) and actual attention values align closely. (D) Construction of intermediate representations for five input subsequences of increasing length (prefixes of the example sequence $01120$, shown left to right). The attention mechanism builds the fractal by taking linear combinations of the three $\vec{v}_\text{s}$ vectors. The colored vectors show the components of the sum for each example subsequence, while the gray dots show all possible vector sums over all sequences at that position.
  • Figure 3: Attention heads combine to capture oscillatory dynamics in belief updating. (A) In the token embedding space, the model uses each attention head to embed tokens on opposite poles of the simplex. (B) The attention patterns of the two heads (shown here averaged over all sequences) act as positive and negative components. When combined, they produce the oscillatory pattern predicted by the exponentiated eigenvalue $\zeta^n = (1 - 3x)^n = (-1)^n(3x-1)^n$.
  • Figure 4: Comparison of model representations and theoretical predictions, with each row corresponding to a different choice of Mess3 hyperparameters. Each subfigure has four columns: (i) the intermediate representation predicted by Eq. \ref{eq:constrained-belief}; (ii) the PCA projection of model activations at the intermediate layer; (iii) the ground-truth belief-state geometry from Eq. \ref{eq:full-belief}; (iv) the PCA projection of the final activations after the MLP.
  • Figure A1: Diagrammatic comparison of the recurrent nature of optimal Bayesian inference (top), the parallel attention mechanism (middle), and the parallel constrained belief updating presented in this paper (bottom).
  • ...and 7 more figures
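
The construction described for Figure 2(D) and Figure 3 can be mimicked with a toy model. This sketch is hypothetical: it keeps only the two ingredients the captions describe, fixed per-token OV vectors and attention weights that decay as ζ^d with ζ = 1 - 3x and d the distance to the destination position, and it omits position-dependent embeddings, softmax normalization, and the MLP. The triangle placement of the OV vectors and all names below are illustrative assumptions, not the trained model's values.

```python
import numpy as np

# Hypothetical OV vectors, one per token, at the vertices of an equilateral
# triangle (the trained model's vectors are learned; see Figure 2A).
_angles = np.pi / 2 + 2 * np.pi * np.arange(3) / 3
OV = np.stack([np.cos(_angles), np.sin(_angles)], axis=1)  # shape (3, 2)


def decay_weights(length: int, x: float) -> np.ndarray:
    """Weight zeta**d on each source position, where d is its distance to the last position."""
    zeta = 1.0 - 3.0 * x
    distances = np.arange(length - 1, -1, -1)  # first token is farthest from the destination
    return zeta**distances


def intermediate_representation(tokens, x: float) -> np.ndarray:
    """Distance-weighted sum of the OV vectors of the tokens seen so far."""
    return decay_weights(len(tokens), x) @ OV[np.asarray(tokens)]


def split_into_heads(length: int, x: float):
    """Split the signed pattern zeta**d into two non-negative patterns."""
    w = decay_weights(length, x)
    return np.clip(w, 0.0, None), np.clip(-w, 0.0, None)


# Prefixes of the example sequence 01120 from Figure 2(D), with x = 0.15.
sequence = [0, 1, 1, 2, 0]
for k in range(1, len(sequence) + 1):
    print(sequence[:k], intermediate_representation(sequence[:k], x=0.15).round(3))

# For x > 1/3 (here x = 0.5, so zeta = -0.5) the pattern alternates in sign.
positive, negative = split_into_heads(4, x=0.5)
print(positive - negative)
```

Because softmax attention weights cannot be negative, a single head cannot produce the alternating pattern on its own. Splitting it into a positive and a negative component, as Figure 3(B) describes for the two trained heads, and letting the two heads write in opposite directions (tokens on opposite poles, Figure 3A) recovers the signed weights ζ^n.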