TD-JEPA: Latent-predictive Representations for Zero-Shot Reinforcement Learning

Marco Bagatella, Matteo Pirotta, Ahmed Touati, Alessandro Lazaric, Andrea Tirinzoni

TL;DR

TD-JEPA proposes a temporal-difference, latent-predictive framework to learn representations that predict long-horizon latent dynamics across multiple policies from reward-free offline data. By training separate state $\phi$ and task $\psi$ encoders alongside a policy-conditioned multi-step predictor, and leveraging TD losses, it enables zero-shot optimization of any downstream reward directly in latent space, with policies $\pi_z(\phi(s))$ derived from predicting $F$-like successor features. Theoretical results show non-collapse under suitable initialization and a low-rank factorization of long-term policy dynamics, with predictor outputs approximating successor features in latent space; policy evaluation errors are bounded by the successor-measure loss, enabling robust zero-shot performance. Empirically, TD-JEPA matches or surpasses strong zero-shot baselines across locomotion, navigation, and manipulation tasks, particularly excelling with pixel inputs and enabling fast downstream adaptation through pre-trained latent representations.

Abstract

Latent prediction--where agents learn by predicting their own latents--has emerged as a powerful paradigm for training general representations in machine learning. In reinforcement learning (RL), this approach has been explored to define auxiliary losses for a variety of settings, including reward-based and unsupervised RL, behavior cloning, and world modeling. While existing methods are typically limited to single-task learning, one-step prediction, or on-policy trajectory data, we show that temporal difference (TD) learning enables learning representations predictive of long-term latent dynamics across multiple policies from offline, reward-free transitions. Building on this, we introduce TD-JEPA, which leverages TD-based latent-predictive representations into unsupervised RL. TD-JEPA trains explicit state and task encoders, a policy-conditioned multi-step predictor, and a set of parameterized policies directly in latent space. This enables zero-shot optimization of any reward function at test time. Theoretically, we show that an idealized variant of TD-JEPA avoids collapse with proper initialization, and learns encoders that capture a low-rank factorization of long-term policy dynamics, while the predictor recovers their successor features in latent space. Empirically, TD-JEPA matches or outperforms state-of-the-art baselines on locomotion, navigation, and manipulation tasks across 13 datasets in ExoRL and OGBench, especially in the challenging setting of zero-shot RL from pixels.
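The claim that TD learning recovers successor features from offline, reward-free transitions can be illustrated on a toy tabular problem. The sketch below is not the TD-JEPA algorithm: it assumes a single fixed policy (so no conditioning on $z$), a small finite state space, a fixed random feature table `psi` standing in for the task encoder, and the convention $F(s) = \mathbb{E}[\sum_{t \ge 0} \gamma^t \psi(s_{t+1})]$. All function names are hypothetical. It runs batch TD backups over a dataset of $(s, s')$ pairs, compares the result with the closed-form solution under the empirical transition matrix, and then evaluates a reward that is linear in `psi` without further learning.

```python
import numpy as np


def sample_offline_transitions(P, n_samples, rng):
    """Draw reward-free (s, s') pairs from a policy-induced transition matrix P."""
    n_states = P.shape[0]
    s = rng.integers(0, n_states, size=n_samples)
    s_next = np.array([rng.choice(n_states, p=P[i]) for i in s])
    return s, s_next


def empirical_transition_matrix(s, s_next, n_states):
    """Row-normalized transition counts estimated from the dataset."""
    counts = np.zeros((n_states, n_states))
    np.add.at(counts, (s, s_next), 1.0)
    row_sums = counts.sum(axis=1, keepdims=True)
    return counts / np.maximum(row_sums, 1.0)


def td_successor_features(s, s_next, psi, gamma, n_iters=500):
    """Batch TD: F[s] <- mean over the dataset of psi[s'] + gamma * F[s']."""
    n_states, dim = psi.shape
    F = np.zeros((n_states, dim))
    counts = np.bincount(s, minlength=n_states).astype(float)
    for _ in range(n_iters):
        targets = psi[s_next] + gamma * F[s_next]  # bootstrapped TD targets
        sums = np.zeros_like(F)
        np.add.at(sums, s, targets)  # accumulate targets per source state
        F = sums / np.maximum(counts, 1.0)[:, None]
    return F


def closed_form_successor_features(P, psi, gamma):
    """Solve F = P (psi + gamma F), i.e. F = (I - gamma P)^{-1} P psi."""
    n_states = P.shape[0]
    return np.linalg.solve(np.eye(n_states) - gamma * P, P @ psi)


def infer_task_vector(psi, rewards_per_state, sampled_states):
    """Least-squares fit of w such that psi[s] @ w approximates r(s)."""
    features = psi[sampled_states]
    targets = rewards_per_state[sampled_states]
    w, *_ = np.linalg.lstsq(features, targets, rcond=None)
    return w


rng = np.random.default_rng(0)
n_states, dim, gamma = 6, 3, 0.9
P = rng.random((n_states, n_states))
P /= P.sum(axis=1, keepdims=True)          # toy policy-induced dynamics
psi = rng.normal(size=(n_states, dim))     # stand-in for a learned task encoder

s, s_next = sample_offline_transitions(P, 3000, rng)
F_td = td_successor_features(s, s_next, psi, gamma)
F_exact = closed_form_successor_features(
    empirical_transition_matrix(s, s_next, n_states), psi, gamma
)
print("max |F_td - F_exact|:", np.abs(F_td - F_exact).max())

# Zero-shot evaluation of a reward that is linear in psi (an assumption here).
reward = psi @ np.array([1.0, -2.0, 0.5])
w_hat = infer_task_vector(psi, reward, s_next)
print("value estimates F_td @ w_hat:", F_td @ w_hat)
```

Because the reward is assumed to lie in the span of `psi`, the product `F_td @ w_hat` equals the discounted value of the policy under the empirical dynamics. With learned encoders and a reward outside that span, this step is only an approximation.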

Paper Structure

This paper contains 51 sections, 7 theorems, 58 equations, 9 figures, 5 tables, 2 algorithms.

Key Result

Proposition 1

For any $\phi$ and $T_\phi$, we have the following equivalence

Figures (9)

  • Figure 1: TD-JEPA trains policies $\pi_z$ parameterized by latents $z$. The predictor, conditioned on $z$, predicts the representations of future states visited by $\pi_z$ (left). When trained via TD, the predictor (arrows on the right) approximates successor features for each policy, i.e., the weighted barycenter (stars) of representations of visited states (circles).
  • Figure 2: Probabilities of improvement: how likely is method X to outperform method Y on a random domain? We report symmetrized 95% simple bootstrap confidence intervals. Dotted lines surround matches in which the improvement is statistically significant.
  • Figure 3: Left: normalized zero-shot performance for latent-predictive methods. Right: difference in normalized performance between TD-JEPA and its symmetric variant. Error bars represent standard errors on normalized performance or its differences, respectively.
  • Figure 4: Normalized performance of zero-shot policies when fine-tuned offline (top) or online (bottom). Initializing the agent to zero-shot solutions (blue and yellow lines) results in sample-efficient learning; frozen representations (dashed) are often expressive enough to enable fast adaptation.
  • Figure 5: Difference in normalized performance between zero-shot baselines with and without an explicit encoder (left); normalized performance difference between symmetric TD-JEPA and its contrastive variant (right). Error bars represent standard errors on normalized performance differences.
  • ...and 4 more figures

Theorems & Definitions (10)

  • Proposition 1
  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Theorem 4
  • Remark 1
  • Theorem 5
  • proof
  • Proposition 2
  • proof