Sparse Autoencoders Reveal Temporal Difference Learning in Large Language Models

Can Demircan, Tankred Saanum, Akshay K. Jagadish, Marcel Binz, Eric Schulz

TL;DR

This work investigates how large language models perform reinforcement learning in-context by analyzing Llama 3 70B with Sparse Autoencoders to extract low-dimensional, TD-like latents from the residual stream. The authors show that representations resembling TD errors, $Q$-values, and successor representations can emerge across transformer blocks even though the model is trained only on next-token prediction, and they demonstrate causal roles for these latents via targeted interventions. Across three tasks—Two-Step, Grid World, and a graph-learning paradigm—the approach reveals both local and global RL-like structure and demonstrates that manipulating TD latents can systematically alter policy and internal representations. The study offers a concrete methodology for mechanistic in-context learning analysis and links computational ideas from reinforcement learning to neural representations observed in both artificial and biological systems.

Abstract

In-context learning, the ability to adapt based on a few examples in the input prompt, is a ubiquitous feature of large language models (LLMs). However, as LLMs' in-context learning abilities continue to improve, understanding this phenomenon mechanistically becomes increasingly important. In particular, it is not well understood how LLMs learn to solve specific classes of problems, such as reinforcement learning (RL) problems, in-context. Through three different tasks, we first show that Llama 3 70B can solve simple RL problems in-context. We then analyze the residual stream of Llama using Sparse Autoencoders (SAEs) and find representations that closely match temporal difference (TD) errors. Notably, these representations emerge despite the model only being trained to predict the next token. We verify that these representations are indeed causally involved in the computation of TD errors and $Q$-values by performing carefully designed interventions on them. Taken together, our work establishes a methodology for studying and manipulating in-context learning with SAEs, paving the way for a more mechanistic understanding.
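
For readers unfamiliar with the quantities named in the abstract, the following is a minimal sketch of a tabular $Q$-learning update and the TD error it produces. It is a textbook illustration, not the authors' code: the function name `q_learning_step`, the learning rate `alpha`, the discount `gamma`, and the toy state and action counts are all assumptions made for this example.

```python
import numpy as np

def q_learning_step(Q, s, a, r, s_next, alpha=0.1, gamma=0.9, terminal=False):
    """Apply one tabular Q-learning update in place and return the TD error.

    Q is a (n_states, n_actions) array of value estimates. alpha and gamma
    are illustrative defaults, not values taken from the paper.
    """
    # Bootstrapped target: observed reward plus discounted best next value.
    target = r if terminal else r + gamma * np.max(Q[s_next])
    # TD error: how much the target deviates from the current estimate.
    td_error = target - Q[s, a]
    # Move the estimate a fraction alpha toward the target.
    Q[s, a] += alpha * td_error
    return td_error

# Toy example: 3 states, 2 actions, all values initialised to zero.
Q = np.zeros((3, 2))
delta = q_learning_step(Q, s=0, a=1, r=1.0, s_next=2)
```

In this toy run the reward is fully unexpected, so the TD error equals the reward, and the updated $Q$-value moves a fraction `alpha` of the way toward it.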

Paper Structure

This paper contains 18 sections, 5 equations, and 10 figures.

Figures (10)

  • Figure 1: We study the mechanisms of in-context learning in three different tasks: (A) the Two-Step Task, (B) the Grid World task, and (C) the graph prediction task. (D) Example pipeline for the Two-Step Task. We prompt Llama as shown on the left. As it selects its actions, we record the internal representation for the tokens that precede the actions, which are highlighted in orange. We train SAEs on these representations (middle) and correlate the learned latents of the SAEs with the TD error signals and other variables of interest we obtain from reinforcement learning agents (right). After identifying such latents, we lesion them and replace Llama's internal representations with reconstructions following the lesions (middle). We then test whether the lesion created the expected effects in the behavior (right).
  • Figure 2: Llama relies on TD-like features to solve RL tasks in-context. (A) Llama 70B often learns the optimal policy in the Two-Step Task through trial and error, whereas the smaller 8B counterpart does not improve beyond chance level. Shaded regions show standard error of the mean. (B) Llama's behavior is best described by a $Q$-learning algorithm. (C & D) SAE features with significant correlations to both reward estimates (myopic values) and $Q$-value estimates, as well as temporal difference errors, appear gradually through the transformer blocks. (E & F) Deactivating a single TD feature in Llama is sufficient to impair performance and make behavior less consistent with $Q$-learning. (G & H) Negatively clamping the TD feature decreases subsequent representations' similarity to $Q$-values and TD errors.
  • Figure 3: Qualitative comparison of the TD error and the best matching SAE feature from block 34 for three separate runs in all three variations of the Two-Step Task. The SAE feature shows similar jumps as TD when Llama encounters surprising events, such as transitioning to unexpected states or receiving an unexpected reward. The dashed line indicates the onset of the change in either transition dynamics or reward function.
  • Figure 4: Llama can predict the actions of a $Q$-learning agent and keeps track of variables similar to $Q$-values and TD errors. (A) Llama predicts action sequences better when given correct information about rewards. (B & C) $Q$-values and a reward-tracking variable, as well as accompanying error signals, are significantly correlated with SAE features. Max correlations shown after Gaussian smoothing with $\sigma=0.5$. (D) Lesioning the TD latents impacts action prediction accuracy, whereas lesioning other features barely affects action predictions. (E & F) Lesioning TD latents also impacts subsequent $Q$-value and TD error representations.
  • Figure 5: Llama learns graph structures through TD-learning, representing them similarly to the successor representation (SR). (A) Llama's state representations, projected into 2D space using multidimensional scaling, show the emergence of latent graph structure across transformer blocks. (B) Llama quickly achieves high accuracy in predicting the next state. Accuracy is averaged over 100 runs. (C) Bottleneck states can be linearly decoded from middle blocks onward. (D & E) Latent representations of SAEs trained on Llama's representations strongly correlate with the SR and associated TD learning signals, outperforming model-based alternatives. Shaded regions in B-E indicate $95\%$ confidence intervals.
  • ...and 5 more figures
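
The analysis pipeline described in the Figure 1 caption (correlate SAE latents with TD errors, then lesion the best-matching latent and reconstruct the representation) can be sketched roughly as follows. This is a toy illustration under strong assumptions: the latents, decoder weights, and TD errors are synthetic, the SAE is not actually trained, and the names `best_matching_latent` and `lesion_and_reconstruct` are hypothetical helpers, not functions from the authors' code.

```python
import numpy as np

def best_matching_latent(latents, target):
    """Return (index, r) of the latent with the largest |Pearson r| to target.

    latents: (n_tokens, n_latents) SAE activations; target: (n_tokens,) signal.
    """
    z = latents - latents.mean(axis=0)
    t = target - target.mean()
    denom = np.sqrt((z ** 2).sum(axis=0) * (t ** 2).sum())
    # Latents with zero variance get a correlation of 0 instead of NaN.
    r = np.divide(z.T @ t, denom, out=np.zeros(latents.shape[1]), where=denom > 0)
    idx = int(np.argmax(np.abs(r)))
    return idx, float(r[idx])

def lesion_and_reconstruct(latents, W_dec, b_dec, idx, value=0.0):
    """Clamp one latent to `value` and decode back to the model's space."""
    clamped = latents.copy()
    clamped[:, idx] = value  # 0.0 deactivates; a negative value clamps negatively
    return clamped @ W_dec + b_dec

# Synthetic stand-ins for recorded data (assumptions, not real Llama activations).
rng = np.random.default_rng(0)
n_tokens, n_latents, d_model = 200, 8, 16
td_errors = rng.normal(size=n_tokens)
latents = np.maximum(rng.normal(size=(n_tokens, n_latents)), 0.0)
# Plant a TD-like latent at index 3 so the search has something to find.
latents[:, 3] = np.maximum(td_errors + 0.1 * rng.normal(size=n_tokens), 0.0)
W_dec = rng.normal(size=(n_latents, d_model))
b_dec = np.zeros(d_model)

idx, r = best_matching_latent(latents, td_errors)
original = latents @ W_dec + b_dec
lesioned = lesion_and_reconstruct(latents, W_dec, b_dec, idx)
```

In an actual experiment, the lesioned reconstruction would be written back into the model's residual stream at the chosen block before continuing the forward pass; that step depends on the model-hooking library in use and is omitted here.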