TD-JEPA: Latent-predictive Representations for Zero-Shot Reinforcement Learning
Marco Bagatella, Matteo Pirotta, Ahmed Touati, Alessandro Lazaric, Andrea Tirinzoni
TL;DR
TD-JEPA is a temporal-difference (TD), latent-predictive framework for learning representations that predict long-horizon latent dynamics across multiple policies from reward-free offline data. It trains separate state ($\phi$) and task ($\psi$) encoders alongside a policy-conditioned multi-step predictor using TD losses. This enables zero-shot optimization of any downstream reward directly in latent space, with policies $\pi_z(\phi(s))$ derived from a predictor whose outputs act like successor features. Theoretically, an idealized variant avoids collapse under suitable initialization and learns a low-rank factorization of long-term policy dynamics, with predictor outputs approximating successor features in latent space; policy evaluation errors are bounded by the successor-measure loss, which supports zero-shot performance. Empirically, TD-JEPA matches or surpasses strong zero-shot baselines across locomotion, navigation, and manipulation tasks, with the largest gains on pixel inputs, and its pre-trained latent representations enable fast downstream adaptation.
Abstract
Latent prediction, where agents learn by predicting their own latents, has emerged as a powerful paradigm for training general representations in machine learning. In reinforcement learning (RL), this approach has been explored to define auxiliary losses for a variety of settings, including reward-based and unsupervised RL, behavior cloning, and world modeling. While existing methods are typically limited to single-task learning, one-step prediction, or on-policy trajectory data, we show that temporal difference (TD) learning enables learning representations predictive of long-term latent dynamics across multiple policies from offline, reward-free transitions. Building on this, we introduce TD-JEPA, which brings TD-based latent-predictive representations to unsupervised RL. TD-JEPA trains explicit state and task encoders, a policy-conditioned multi-step predictor, and a set of parameterized policies directly in latent space. This enables zero-shot optimization of any reward function at test time. Theoretically, we show that an idealized variant of TD-JEPA avoids collapse with proper initialization, and learns encoders that capture a low-rank factorization of long-term policy dynamics, while the predictor recovers their successor features in latent space. Empirically, TD-JEPA matches or outperforms state-of-the-art baselines on locomotion, navigation, and manipulation tasks across 13 datasets in ExoRL and OGBench, especially in the challenging setting of zero-shot RL from pixels.
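
To make the link between successor features and zero-shot evaluation concrete, the following is a minimal tabular sketch, not the TD-JEPA algorithm itself. It assumes a small hypothetical Markov chain induced by one fixed policy, a random matrix `psi` standing in for a task encoder, and a reward that is exactly linear in `psi` and collected at the next state. It omits the state encoder $\phi$, action conditioning, and neural networks entirely. All names (`P`, `psi`, `F`, `z_hat`) are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, dim, gamma = 6, 3, 0.9

# Hypothetical state-to-state transition matrix of one fixed policy (rows sum to 1).
P = rng.random((n_states, n_states))
P /= P.sum(axis=1, keepdims=True)

# Stand-in for a task encoder psi: one feature vector per state (assumption).
psi = rng.normal(size=(n_states, dim))

# Reward-free phase: TD-style fixed-point iteration for successor features,
#   F(s) = E[ psi(s') + gamma * F(s') ].
F = np.zeros((n_states, dim))
for _ in range(500):
    F = P @ (psi + gamma * F)

# Test time: a reward is revealed. Here it is exactly linear in psi (assumption).
z_true = rng.normal(size=dim)
reward = psi @ z_true

# Infer the task vector by least-squares regression of rewards onto psi.
z_hat, *_ = np.linalg.lstsq(psi, reward, rcond=None)

# Zero-shot policy evaluation: no further learning, just an inner product.
value_zero_shot = F @ z_hat

# Reference: solve V = P (r + gamma V) directly using the revealed reward.
value_exact = np.linalg.solve(np.eye(n_states) - gamma * P, P @ reward)

print(np.max(np.abs(value_zero_shot - value_exact)))
```

In this idealized setting the two value estimates agree up to numerical precision, because the reward lies exactly in the span of `psi`. When it does not, the regression residual carries over into the value estimate.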
