Multi-step Predictive Coding Leads To Simplicity Bias

Aviv Ratzon, Omri Barak

TL;DR

This work investigates when predictive coding yields interpretable latent representations by analyzing a minimal linear predictive coding problem with horizon $A$ and depth $L$, showing that the leading singular direction of the OLS estimator $\boldsymbol{X}^\top \boldsymbol{X}$ dominates as $A$ grows with the environment size $S$. The authors combine gradient-descent dynamics with an OLS-based bias analysis to explain why deeper networks trained with multi-step horizons converge to structured, low-rank encodings of the latent state, contrasting with single-step training. They extend these insights to nonlinear and more naturalistic settings, including piecewise-linear environments and MNIST-based tasks, where multi-step prediction yields a low-dimensional manifold ordered by the latent variable, while regularization fails to recover such structure. The findings offer a principled account of when predictive coding yields interpretable world models and inform how horizon, depth, and training dynamics shape learned representations in both artificial and biological systems.

Abstract

Predictive coding is a framework for understanding the formation of low-dimensional internal representations mirroring the environment's latent structure. The conditions under which such representations emerge remain unclear. In this work, we investigate how the prediction horizon and network depth shape the solutions of predictive coding tasks. Using a minimal abstract setting inspired by prior work, we show empirically and theoretically that sufficiently deep networks trained with multi-step prediction horizons consistently recover the underlying latent structure, a phenomenon explained through the Ordinary Least Squares estimator structure and biases in learning dynamics. We then extend these insights to nonlinear networks and complex datasets, including piecewise linear functions, MNIST, multiple latent states and higher dimensional state geometries. Our results provide a principled understanding of when and why predictive coding induces structured representations, bridging the gap between empirical observations and theoretical foundations.

Paper Structure

This paper contains 27 sections, 19 equations, 10 figures.

Figures (10)

  • Figure 1: Illustration of the multi-step predictive coding setting, where time is abstracted away. An agent is acting in an environment, producing a set of observations and actions in its trajectory. The task is to predict, for each action and observation pair, the following observation. The environment has an underlying structure, and training a model on a predictive coding task sometimes generates a representation of this latent structure. Recent work has shown that increasing the prediction horizon can lead to more accurate and stable representations [levenstein2024sequential].
  • Figure 2: Top: the first two principal components of hidden activations for networks trained on single-step prediction (left) and multi-step prediction with maximal action $A = S/2$ (right). Bottom: quantitative metrics across values of $A$ and network depth $L$. Left: NC1 decreases with increasing $A$ and $L$, indicating more compact class clusters. Middle: normalized margins (relative to $L=2$) decrease with depth. Right: representations become increasingly aligned with the target state as $A$ and $L$ grow. See the appendix for detailed metric definitions. The analysis was done for $|a| \le 1$.
  • Figure 3: Singular values and vectors of the OLS estimator and of the model's effective weight matrix. For larger $A$ both become lower dimensional, and the leading singular vector becomes the transformation from the input state and action to the output state. For $A=1$, since no single direction explains most of the variance, the model's singular vectors are mostly decoupled from those of the OLS estimator. We also show a comparison between a shallow network ($L=2$) and a deep network ($L=9$). As can be seen in the rightmost column, shallow networks depend on directions of the input space to classify the data.
  • Figure 4: Results for the task with two independent environments. Data are projected onto the first two singular vectors of the OLS estimator, confirming that observations and actions from each environment are orthogonal. In the single-step setting, the model representations remain unaligned, whereas in the multi-step setting they collapse into a shared low-rank structure that aligns the two environments and reveals their underlying symmetry.
  • Figure 5: Training a deep nonlinear network on a predictive task with observations generated from a piecewise linear function containing three discontinuities. When the action distribution is narrow, the learned representations primarily mirror local autocorrelation. In contrast, with a wider action distribution, the network organizes its hidden representations along a smooth one-dimensional manifold that bridges the discontinuities, thereby recovering the underlying latent state. The top figure shows the PCA space, and the bottom figure shows the distance matrix sorted by the state variable.
  • ...and 5 more figures
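
The singular-value analysis described in the Figure 3 caption can be explored with a small NumPy sketch. This is not the authors' code. It assumes a toy encoding that may differ from the paper's setup: states on a ring of size $S$, one-hot observations, one-hot actions drawn from $\{-A, \dots, A\}$, and one-hot next-state targets. The function names `build_dataset` and `ols_spectrum` are hypothetical. The sketch computes the minimum-norm least-squares estimator with a pseudo-inverse and reports how much spectral energy falls on its leading singular direction for a single-step and a multi-step horizon.

```python
import numpy as np


def build_dataset(S, A):
    """Enumerate every (state, action) pair on a ring of S states.

    Assumed encoding (illustrative, not taken from the paper):
    one-hot state, one-hot action in {-A, ..., A}, one-hot next state.
    """
    actions = np.arange(-A, A + 1)
    states = np.repeat(np.arange(S), len(actions))
    acts = np.tile(actions, S)
    next_states = (states + acts) % S  # wrap around the ring

    obs = np.eye(S)[states]                # (N, S) one-hot observations
    act = np.eye(len(actions))[acts + A]   # (N, 2A+1) one-hot actions
    X = np.hstack([obs, act])              # model input: observation and action
    Y = np.eye(S)[next_states]             # target: next observation
    return X, Y


def ols_spectrum(S, A):
    """Return the OLS weights, their singular values, and normalized energy."""
    X, Y = build_dataset(S, A)
    W = np.linalg.pinv(X) @ Y                 # minimum-norm least-squares solution
    sv = np.linalg.svd(W, compute_uv=False)   # sorted in descending order
    energy = sv**2 / np.sum(sv**2)            # fraction of energy per direction
    return W, sv, energy


S = 20
for A in (1, S // 2):
    _, sv, energy = ols_spectrum(S, A)
    print(f"A={A:2d}  top singular value={sv[0]:.3f}  "
          f"energy in leading direction={energy[0]:.3f}")
```

Comparing the printed energy fractions across values of $A$ gives a rough view of how the horizon changes the estimator's spectrum under these assumptions. Reproducing the paper's figures would require its actual observation and action encodings.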