
Transformers Can Learn Temporal Difference Methods for In-Context Reinforcement Learning

Jiuqi Wang, Ethan Blaser, Hadi Daneshmand, Shangtong Zhang

TL;DR

This paper investigates in-context reinforcement learning (ICRL) with a focus on policy evaluation, providing both empirical and theoretical evidence that a pretrained transformer can implement temporal-difference (TD) methods in its forward pass. It constructs an explicit linear-transformer setup that performs batch TD(0) during inference and shows that, under TD-based multi-task pretraining, the network weights converge to TD-like structures within an invariant set of the training dynamics. The authors extend the analysis to TD(λ) and outline how residual gradient and average-reward TD can also be realized in-context, arguing for a broader class of in-context RL capabilities. The work positions TD as a natural emergent algorithm from reinforcement pretraining and proposes a path toward white-box understanding of ICRL, while noting limitations such as reliance on linear attention and policy-evaluation-centric experiments. Overall, the findings suggest that forward-pass RL algorithms can be learned and deployed without parameter updates, with potential implications for rapid generalization across unseen tasks.
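The batch TD(0) iteration mentioned above can be sketched in NumPy as follows. This is a minimal sketch, not the paper's code. It assumes linear value estimates $\hat v(s) = \langle \phi(s), w \rangle$, a context of $n$ transitions $(\phi_j, R_{j+1}, \phi'_j)$, and one step size per iteration, so that iteration $l$ plays the role of layer $l$. The function names `batch_td0` and `td_fixed_point` are hypothetical.

```python
import numpy as np

def batch_td0(phi, rewards, phi_next, gamma, alphas):
    """Run batch TD(0) with linear features and return every iterate w_0, ..., w_L.

    phi:      (n, d) features of the context states s_j
    rewards:  (n,)   rewards R_{j+1}
    phi_next: (n, d) features of the successor states
    alphas:   one step size per iteration (one per "layer" in the analogy)
    """
    n, d = phi.shape
    w = np.zeros(d)                      # w_0 = 0
    iterates = [w.copy()]
    for alpha in alphas:
        # TD errors for all transitions, computed with the same (current) w
        delta = rewards + gamma * phi_next @ w - phi @ w
        # Average the semi-gradient TD updates over the whole batch
        w = w + (alpha / n) * phi.T @ delta
        iterates.append(w.copy())
    return iterates

def td_fixed_point(phi, rewards, phi_next, gamma):
    """Solve A w = b, the point that a batch TD(0) update leaves unchanged (LSTD)."""
    A = phi.T @ (phi - gamma * phi_next)
    b = phi.T @ rewards
    return np.linalg.solve(A, b)

rng = np.random.default_rng(0)
n, d, gamma = 50, 4, 0.9
phi = rng.normal(size=(n, d))
phi_next = rng.normal(size=(n, d))
rewards = rng.normal(size=n)
ws = batch_td0(phi, rewards, phi_next, gamma, alphas=[0.1] * 5)
```

The value estimate for a query state $s$ after $l$ iterations is then `phi_s @ ws[l]`, which is the quantity a forward-pass TD(0) transformer would need to reproduce layer by layer.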

Abstract

Traditionally, reinforcement learning (RL) agents learn to solve new tasks by updating their neural network parameters through interactions with the task environment. However, recent works demonstrate that some RL agents, after certain pretraining procedures, can learn to solve unseen new tasks without parameter updates, a phenomenon known as in-context reinforcement learning (ICRL). The empirical success of ICRL is widely attributed to the hypothesis that the forward pass of the pretrained agent neural network implements an RL algorithm. In this paper, we support this hypothesis by showing, both empirically and theoretically, that when a transformer is trained for policy evaluation tasks, it can discover and learn to implement temporal difference learning in its forward pass.

Paper Structure (32 sections, 8 theorems, 203 equations, 13 figures, 4 algorithms)

Key Result

Theorem 1

Consider the $L$-layer linear transformer that follows the paper's layer-update equation, uses the TD(0) mask, and is parameterized by $\{P_l^{\text{TD}}, Q_l^{\text{TD}}\}_{l=0,\dots,L-1}$ as defined in the paper. Let $y_l^{(n+1)}$ be the bottom-right element of the $l$-th layer's output, i.e., $y_l^{(n+1)} \doteq Z_l \ldots$
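
To see how one linear-attention layer can carry out a batch TD(0) step, consider the hand-built example below. It is a hypothetical construction for illustration, not the paper's $P_0^{\text{TD}}, Q_0^{\text{TD}}$. It assumes the common linear-attention form $Z + \frac{1}{n} P Z M (Z^\top Q Z)$, an embedding whose context columns are $z_j = (\phi_j, \phi'_j, R_{j+1})$ and whose last (query) column is $(\phi, 0, 0)$, and a mask $M$ that keeps only the $n$ context columns. Starting from $w_0 = 0$, the bottom-right output entry then equals $\langle \phi, w_1 \rangle$ with $w_1 = \frac{\alpha}{n}\sum_j R_{j+1}\phi_j$. The paper's construction may use a different layout or sign convention. Because $P$ and $Q$ enter only through a product, rescaling $P$ by $c$ and $Q$ by $1/c$ (or flipping both signs) leaves the layer unchanged, which the sketch also checks.

```python
import numpy as np

def linear_attention_layer(Z, P, Q, M):
    """One linear-attention layer with a residual connection (assumed form)."""
    n = M.shape[0] - 1                   # M is (n+1) x (n+1): n context columns + 1 query
    return Z + (P @ Z @ M @ (Z.T @ Q @ Z)) / n

def one_step_td_weights(d, alpha):
    """Hypothetical P, Q whose readout equals the value after one TD(0) step from w_0 = 0."""
    D = 2 * d + 1
    P = np.zeros((D, D))
    P[-1, -1] = 1.0                      # bottom row of P @ Z copies the reward row
    Q = np.zeros((D, D))
    Q[:d, :d] = alpha * np.eye(d)        # z_j^T Q z_query = alpha * <phi_j, phi_query>
    return P, Q

rng = np.random.default_rng(1)
n, d, alpha = 20, 3, 0.5
phi = rng.normal(size=(n, d))            # phi_j
phi_next = rng.normal(size=(n, d))       # phi'_j (not needed for this first step, since w_0 = 0)
rewards = rng.normal(size=n)             # R_{j+1}
phi_query = rng.normal(size=d)

# Embedding: context columns (phi_j, phi'_j, R_{j+1}), query column (phi, 0, 0)
Z = np.zeros((2 * d + 1, n + 1))
Z[:d, :n] = phi.T
Z[d:2 * d, :n] = phi_next.T
Z[-1, :n] = rewards
Z[:d, n] = phi_query

M = np.diag(np.append(np.ones(n), 0.0))  # attend to the context columns only

P, Q = one_step_td_weights(d, alpha)
y = linear_attention_layer(Z, P, Q, M)[-1, -1]

w1 = (alpha / n) * phi.T @ rewards       # one batch TD(0) step from w_0 = 0
print(np.isclose(y, phi_query @ w1))     # True

# Inverse rescaling or a joint sign flip of (P, Q) gives the same layer output
c = 3.0
y_scaled = linear_attention_layer(Z, c * P, Q / c, M)[-1, -1]
y_flipped = linear_attention_layer(Z, -P, -Q, M)[-1, -1]
```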

Figures (13)

  • Figure 1: A transformer capable of in-context policy evaluation. This 15-layer transformer $\text{TF}_{\theta_*}$ takes the context $\tau_t$ and a state of interest $s$ as input and outputs $\text{TF}_{\theta_*}(\tau_t, s)$ as its estimate of the state value $v_\pi(s)$. The $y$-axis is the mean square value error (MSVE) $\sum_s d_\pi(s) \left(\text{TF}_{\theta_*}(\tau_t, s) - v_\pi(s)\right)^2$, where $d_\pi$ is the stationary state distribution. The curves are averaged over 300 randomly generated policy evaluation tasks, and the shaded regions indicate standard errors. The tasks vary in state space, transition function, reward function, and policy, yet a single $\theta_*$ is used for all of them. See the appendix on this demonstration for more details.
  • Figure 2: Visualization of the learned transformers and the learning progress. Both (a) and (b) are averaged across 30 seeds, and the shaded regions in (b) denote standard errors. Since $P_0$ and $Q_0$ appear in the same product in the linear attention equation, the algorithm can rescale them inversely (multiplying $P_0$ by $c \neq 0$ and $Q_0$ by $1/c$, which includes flipping the sign of both) and still end up with exactly the same transformer. Therefore, to make the visualizations informative, we first rescale $P_0$ and $Q_0$ appropriately. See the appendix on element-wise metrics for details.
  • Figure 3: Boyan's Chain of $m$ States
  • Figure 4: Visualization of the learned autoregressive transformers and the learning progress. Results are averaged across 30 seeds, and the shaded regions denote standard errors. See the appendix on element-wise metrics for details about how $P_0$ and $Q_0$ are normalized before visualization.
  • Figure 5: Value difference (VD), implicit weight similarity (IWS), and sensitivity similarity (SS) between the learned autoregressive transformers and batch TD, for different numbers of layers. All curves are averaged over 30 seeds, and the shaded regions are the standard errors.
  • ...and 8 more figures
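
The MSVE used in Figure 1 weights squared value errors by the stationary distribution $d_\pi$. The sketch below computes it for a small continuing chain loosely modeled on Boyan's chain (Figure 3). The transition structure and rewards are assumptions for illustration, not the paper's task generator: from state $i$ the chain moves to $i+1$ or $i+2$ with probability 0.5 each, the second-to-last state moves to the last, and the last state loops back to state 0 so that the chain is continuing. The helper names are hypothetical.

```python
import numpy as np

def boyan_style_chain(m):
    """Transition matrix of a continuing m-state chain (assumed structure, m >= 3)."""
    P = np.zeros((m, m))
    for i in range(m - 2):
        P[i, i + 1] = 0.5                # step one state forward
        P[i, i + 2] = 0.5                # or skip ahead two states
    P[m - 2, m - 1] = 1.0                # second-to-last state moves to the last one
    P[m - 1, 0] = 1.0                    # assumption: the last state loops back to the start
    return P

def stationary_distribution(P):
    """Solve d^T P = d^T with sum(d) = 1 (assumes a unique stationary distribution)."""
    m = P.shape[0]
    A = np.vstack([P.T - np.eye(m), np.ones((1, m))])
    b = np.zeros(m + 1)
    b[-1] = 1.0
    return np.linalg.lstsq(A, b, rcond=None)[0]

def true_values(P, r, gamma):
    """v_pi = (I - gamma P)^{-1} r for expected per-state rewards r."""
    return np.linalg.solve(np.eye(P.shape[0]) - gamma * P, r)

def msve(v_hat, v_true, d):
    """Mean square value error: sum_s d(s) (v_hat(s) - v_true(s))^2."""
    return float(np.sum(d * (v_hat - v_true) ** 2))

rng = np.random.default_rng(0)
m, gamma = 7, 0.9
P = boyan_style_chain(m)
r = rng.uniform(-1.0, 1.0, size=m)       # assumption: random per-state expected rewards
d = stationary_distribution(P)
v = true_values(P, r, gamma)
print(msve(np.zeros(m), v, d))           # error of the all-zero value estimate
```

Averaging such an MSVE over many randomly generated tasks is the kind of aggregate that Figure 1 reports, although the paper's tasks also vary the state space and policy.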

Theorems & Definitions (16)

  • Theorem 1: Forward pass as TD(0)
  • Corollary 1
  • Theorem 2
  • Corollary 2: Forward pass as Residual Gradient
  • Corollary 3: Forward pass as TD($\lambda$)
  • Theorem 3: Forward pass as average-reward TD
  • Proof
  • Proof
  • Lemma A.3.1
  • Proof
  • ...and 6 more
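
Corollaries 2 and 3 and Theorem 3 name three further forward-pass algorithms. The sketch below writes their batch update directions side by side for linear features. It is a minimal sketch under assumed textbook forms: accumulating eligibility traces along a single trajectory for TD($\lambda$), the full gradient of the mean squared TD error for residual gradient, and an empirical-mean reward estimate for average-reward TD. The paper's exact parameterizations may differ, and all function names are hypothetical.

```python
import numpy as np

def td_errors(w, phi, rewards, phi_next, gamma):
    """Discounted TD errors delta_j = R_{j+1} + gamma <phi'_j, w> - <phi_j, w>."""
    return rewards + gamma * phi_next @ w - phi @ w

def td_lambda_direction(w, phi, rewards, phi_next, gamma, lam):
    """Batch TD(lambda): weight each TD error by an accumulating eligibility trace.

    Assumes the rows of phi are consecutive states of one trajectory.
    """
    delta = td_errors(w, phi, rewards, phi_next, gamma)
    trace = np.zeros_like(w)
    direction = np.zeros_like(w)
    for j in range(len(rewards)):
        trace = gamma * lam * trace + phi[j]   # e_j = gamma * lambda * e_{j-1} + phi_j
        direction += delta[j] * trace
    return direction / len(rewards)

def residual_gradient_direction(w, phi, rewards, phi_next, gamma):
    """Negative gradient of 0.5 * mean(delta^2), differentiating through phi'_j as well."""
    delta = td_errors(w, phi, rewards, phi_next, gamma)
    return (phi - gamma * phi_next).T @ delta / len(rewards)

def average_reward_td_direction(w, phi, rewards, phi_next):
    """Differential TD: subtract an average-reward estimate instead of discounting."""
    r_bar = rewards.mean()                     # assumption: empirical mean of context rewards
    delta = rewards - r_bar + phi_next @ w - phi @ w
    return phi.T @ delta / len(rewards)

rng = np.random.default_rng(0)
n, d, gamma, alpha = 30, 4, 0.9, 0.1
phi = rng.normal(size=(n, d))
phi_next = np.vstack([phi[1:], rng.normal(size=(1, d))])  # successor of s_j is s_{j+1}
rewards = rng.normal(size=n)

w = np.zeros(d)
for _ in range(10):                            # ten "layers" of batch TD(lambda)
    w = w + alpha * td_lambda_direction(w, phi, rewards, phi_next, gamma, lam=0.8)
```

With `lam=0` the trace reduces to $\phi_j$ and the TD($\lambda$) direction coincides with the batch TD(0) direction; residual gradient differs from TD(0) only in using $\phi_j - \gamma\phi'_j$ instead of $\phi_j$ as the update direction for each TD error.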