Measuring In-Context Computation Complexity via Hidden State Prediction

Vincent Herrmann, Róbert Csordás, Jürgen Schmidhuber

TL;DR

The paper tackles the problem of identifying when neural sequence models perform genuinely interesting in-context computation, arguing that the standard next-token loss is insufficient. It introduces the Prediction of Hidden States (PHi) layer, an information-bottleneck mechanism that predicts future hidden states and yields a KL-based loss $L_{\text{PHi}}$ quantifying the novel information in in-context computations. Whether PHi is trained jointly with a model or inserted into a pretrained one, the authors show that $L_{\text{PHi}}$ correlates with task complexity across diverse settings, including in-context language learning with probabilistic finite automata (PFAs), mathematical reasoning, and reasoning chains on GSM-8k and MATH. The approach provides a flexible, architecture-agnostic tool for diagnosing and potentially guiding non-trivial in-context reasoning, with implications for intrinsic motivation and self-supervised objective design in AI systems.

Abstract

Detecting when a neural sequence model does "interesting" computation is an open problem. The next token prediction loss is a poor indicator: Low loss can stem from trivially predictable sequences that are uninteresting, while high loss may reflect unpredictable but also irrelevant information that can be ignored by the model. We propose a better metric: measuring the model's ability to predict its own future hidden states. We show empirically that this metric -- in contrast to the next token prediction loss -- correlates with the intuitive interestingness of the task. To measure predictability, we introduce the architecture-agnostic "prediction of hidden states" (PHi) layer that serves as an information bottleneck on the main pathway of the network (e.g., the residual stream in Transformers). We propose a novel learned predictive prior that enables us to measure the novel information gained in each computation step, which serves as our metric. We show empirically that our metric predicts the description length of formal languages learned in-context, the complexity of mathematical reasoning problems, and the correctness of self-generated reasoning chains.

Paper Structure

This paper contains 27 sections, 8 equations, 19 figures.

Figures (19)

  • Figure 1: Interesting tasks, like in-context learning, or modeling of code and literature, exhibit high hidden state prediction (PHi) loss, while boring or trivial tasks, such as retrieving memorized sequences or modeling random structureless data, show low PHi loss. Next token loss provides no meaningful insight into task complexity. Results for a specialized transformer model (blue) and a pre-trained LLM (green), with PHi loss scales differing due to hidden state size. See Sections \ref{subsec:exp_boring_interesting} and \ref{sec:exp_llama_tasks} for details.
  • Figure 2: The structure of our PHi layer. It can be inserted in the middle of any next token prediction architecture, where it reconstructs the hidden states through an information bottleneck. It consists of a posterior encoder $q_\psi(z_t \mid h_t)$ that predicts the latent code $z_t$, a decoder $a_\xi$ that reconstructs the hidden state as $h'_t = a_\xi(z_t)$, and a learned autoregressive prior $p_\chi(z_t \mid z_1,\dots,z_{t-1})$ that predicts $z_t$ from the past latent codes. We propose to use the KL divergence between the posterior and the prior to quantify the complexity of the "in-context computation" performed by the model.
  • Figure 3: Next-token prediction loss on each of the four tasks, for both the Transformer and the LSTM. Memorized tasks yield the lowest loss, random is the highest, and the in-context language learning task is intermediate. Bootstrapped mean with 95% confidence intervals across 10 runs.
  • Figure 4: PHi loss, relative to the performance on the memorized sequences, for the same four tasks. In-context language learning shows a significantly higher PHi loss than the other tasks, indicating that the model is performing non-trivial computation in its hidden states to infer the unknown automaton. Bootstrapped mean with 95% confidence intervals across 10 runs.
  • Figure 5: Token-wise PHi loss (y-axis) versus binned next-token losses (x-axis), stratified by PFA complexity (color, grouped into 10 levels, from 1 (simple) to 10 (complex)). Across all next-token prediction losses, more complex PFAs result in a higher PHi loss.
  • ...and 14 more figures
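
The Figure 2 caption defines the PHi loss as the KL divergence between the posterior $q_\psi(z_t \mid h_t)$ and the autoregressive prior $p_\chi(z_t \mid z_1,\dots,z_{t-1})$. The following is a minimal sketch of that per-token quantity, not the authors' implementation. It assumes both distributions are diagonal Gaussians (the text above does not state the distribution family), and the function names `diag_gaussian_kl` and `phi_loss_per_token` are hypothetical. The encoder and prior networks are omitted; their outputs are passed in as plain lists of means and standard deviations.

```python
import math


def diag_gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """KL(q || p) in nats between two diagonal Gaussians.

    Assumption: posterior q and prior p are diagonal Gaussians, each given
    as per-dimension lists of means and standard deviations.
    """
    kl = 0.0
    for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p):
        # Closed-form KL for one dimension, summed over latent dimensions.
        kl += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return kl


def phi_loss_per_token(posteriors, priors):
    """Per-token KL between posterior and prior parameters.

    posteriors[t] and priors[t] are (mean_list, std_list) pairs for step t;
    in the setup described above, the prior for step t would be produced
    from the latent codes of earlier steps.
    """
    return [diag_gaussian_kl(*q, *p) for q, p in zip(posteriors, priors)]


# Toy example with a 2-dimensional latent code and two time steps.
posteriors = [([0.0, 0.0], [1.0, 1.0]), ([1.0, 1.0], [1.0, 1.0])]
priors = [([0.0, 0.0], [1.0, 1.0]), ([0.0, 0.0], [1.0, 1.0])]
print(phi_loss_per_token(posteriors, priors))  # [0.0, 1.0]
```

In this toy example the first step's posterior matches the prior exactly, so it contributes no novel information, while the second step's posterior mean is shifted away from what the prior predicted and therefore incurs a positive loss.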