
JEDI: Latent End-to-end Diffusion Mitigates Agent-Human Performance Asymmetry in Model-Based Reinforcement Learning

Jing Yu Lim, Zarif Ikram, Samson Yu, Haozhe Ma, Tze-Yun Leong, Dianbo Liu

TL;DR

This work tackles the misalignment between human and agent performance in Atari100k by revealing a pronounced task-wise asymmetry in pixel-based MBRL and proposing JEDI, a latent end-to-end diffusion world model trained with a self-consistency objective inspired by JEPA. By integrating a latent diffusion dynamics model with an end-to-end encoder, JEDI achieves state-of-the-art results on Human-Optimal tasks while staying competitive overall, and it does so with substantial efficiency gains (faster inference, faster training, and lower memory) thanks to latent compression. The key contribution is demonstrating that temporally structured latent representations learned end-to-end via diffusion can bridge the gap between agent and human performance, challenging prior reliance on pixel-space diffusion and detached encoders. Overall, the work advances holistic human-level performance assessment in Atari100k and offers a scalable, efficient framework for temporally aware, model-based RL.

Abstract

Recent advances in model-based reinforcement learning (MBRL) have achieved superhuman performance on the Atari100k benchmark, driven by reinforcement learning agents trained on powerful diffusion world models. However, we find that current aggregate metrics mask a major performance asymmetry: MBRL agents dramatically outperform humans on some tasks while drastically underperforming on others, with the former inflating the aggregate metrics. This asymmetry is especially pronounced in pixel-based agents trained with diffusion world models. In this work, we focus on pixel-based agents as an initial attempt to reverse the worrying upward trend in their asymmetry. First, we address the problematic aggregates by partitioning all tasks into Agent-Optimal and Human-Optimal sets and advocate giving equal importance to metrics from both sets. Next, we hypothesize that this pronounced asymmetry stems from the lack of a temporally structured latent space trained with the world model objective in pixel-based methods. Lastly, to address this issue, we propose Joint Embedding DIffusion (JEDI), a novel latent diffusion world model trained end-to-end with a self-consistency objective. JEDI outperforms SOTA models on Human-Optimal tasks while remaining competitive across the Atari100k benchmark, and it runs 3 times faster and uses 43% less memory than the latest pixel-based diffusion baseline. Overall, our work rethinks what it truly means to cross human-level performance in Atari100k.
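
The proposal to report Agent-Optimal and Human-Optimal metrics separately can be illustrated with a small, dependency-free sketch. It rests on stated assumptions: the partition is supplied as a fixed set of task names (the paper's rule for assigning tasks to each set is not reproduced here), the Human Normalized Score uses the usual (agent - random) / (human - random) normalization, IQM is the 25% trimmed mean, and the optimality gap is the mean shortfall below a normalized score of 1.0. The task names and returns are made-up toy values, not results from the paper.

```python
import statistics

# Toy sketch: HNS aggregates reported per task set. The Agent-Optimal /
# Human-Optimal partition is an INPUT here (hypothetical), not derived.


def hns(agent: float, rand: float, human: float) -> float:
    """Human Normalized Score: 0.0 = random policy, 1.0 = human reference."""
    return (agent - rand) / (human - rand)


def iqm(scores: list[float]) -> float:
    """Interquartile mean: drop the lowest and highest 25% of scores and
    average the rest (same cut as scipy.stats.trim_mean(x, 0.25))."""
    s = sorted(scores)
    k = int(0.25 * len(s))
    return statistics.mean(s[k:len(s) - k])


def optimality_gap(scores: list[float], gamma: float = 1.0) -> float:
    """Mean shortfall below human level; scores at or above gamma count as 0."""
    return statistics.mean(max(gamma - x, 0.0) for x in scores)


def summarize(scores: list[float]) -> dict[str, float]:
    return {
        "mean": statistics.mean(scores),
        "median": statistics.median(scores),
        "iqm": iqm(scores),
        "optimality_gap": optimality_gap(scores),
    }


def summarize_by_set(raw: dict[str, tuple[float, float, float]],
                     agent_optimal: set[str]) -> dict[str, dict[str, float]]:
    """raw maps task -> (agent_return, random_return, human_return).
    Tasks not listed in agent_optimal are treated as Human-Optimal."""
    scores = {task: hns(*returns) for task, returns in raw.items()}
    agent_set = [s for t, s in scores.items() if t in agent_optimal]
    human_set = [s for t, s in scores.items() if t not in agent_optimal]
    return {
        "overall": summarize(list(scores.values())),
        "agent_optimal": summarize(agent_set),
        "human_optimal": summarize(human_set),
    }


# Made-up toy returns (NOT results from the paper): (agent, random, human).
raw = {
    "task_a": (1200.0, 200.0, 700.0),   # HNS 2.0
    "task_b": (4200.0, 200.0, 700.0),   # HNS 8.0
    "task_c": (300.0, 100.0, 1100.0),   # HNS 0.2
    "task_d": (500.0, 100.0, 1100.0),   # HNS 0.4
}
agent_optimal = {"task_a", "task_b"}  # hypothetical partition

for name, stats in summarize_by_set(raw, agent_optimal).items():
    print(name, {k: round(v, 3) for k, v in stats.items()})
```

With these toy values the overall mean HNS is 2.65, well above human level, even though both Human-Optimal tasks score below 0.5; reporting the two sets side by side makes that asymmetry visible instead of letting two large scores dominate a single aggregate.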

Paper Structure

This paper contains 19 sections, 7 equations, 14 figures, 7 tables, and 1 algorithm.

Figures (14)

  • Figure 1: Left: Performance asymmetry on RL tasks. Tasks where humans excel often differ from those where RL agents excel, frequently due to reward hacking or representation learning failures (skalse2022RewardHack; guo2021machine). Right: MBRL agents exhibit a large performance asymmetry between Human-Optimal and Agent-Optimal tasks, often scoring over an order of magnitude higher on the latter. This disparity is especially pronounced in pixel-based MBRL agents.
  • Figure 2: Left: Example trajectory of DIAMOND on BankHeist, which displays self-sabotaging behaviour because DIAMOND is unable to reason about the game's high-complexity action space. Right: DIAMOND vs. other agents in different environment settings. DIAMOND's performance is biased towards Non-Action environments with Low Action Space.
  • Figure 3: Joint Embedding DIffusion (JEDI) World Model. (a) During training, the current image observations $[x_{t-3}, \dots, x_t]$ are passed through the world model encoder to derive the low-dimensional latent states $[z_{t-3}^0, \dots, z_t^0]$, which are passed into the diffusion model together with the actions as conditioning. The target next state $z_{t+1}^\tau$ is derived by passing $x_{t+1}$ through the same encoder, but with stop-gradient. Noise is sampled and scaled according to the sampled diffusion time step $\tau$, and the result is passed to the diffusion model as input. Given these, the diffusion model learns to predict the direction towards the next state, training both the encoder and the diffusion model end-to-end. (b) During inference, the same conditioning is derived and passed into the diffusion model. Random noise is sampled and passed into the diffusion model as input. Given these, the model predicts the direction towards the next state and iteratively denoises to arrive at the clean next state.
  • Figure 4: Overall aggregates on Atari100k. Mean, Median, and Interquartile Mean (IQM) are computed on Human Normalized Scores (↑); Optimality Gap (↓) is the overall gap to human-level performance. JEDI achieves SOTA in Optimality Gap while achieving runner-up performance in Mean. It outperforms the SOTA pixel-based agent baseline in Median.
  • Figure 5: JEDI vs. DreamerV3 in environments with stochastic frame-skipping. JEDI consistently outperforms the SOTA latent-based agent in all aspects. JEDI's performance is more robust to high task-related aleatoric uncertainty.
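
The training and inference paths described in the Figure 3 caption can be illustrated with a toy, dependency-free Python sketch. It is not the authors' implementation: the EDM-style noising $z^\tau = z^0 + \sigma\epsilon$, the choice to have the denoiser predict the clean latent, the deterministic Euler sampler, and the helpers `encode` and `oracle_denoiser` are all assumptions made for illustration. Because plain Python has no autograd, the stop-gradient on the target latent is only marked in comments; in an autodiff framework it would correspond to something like `jax.lax.stop_gradient` in JAX or `.detach()` in PyTorch.

```python
import random

# Toy sketch of the Figure 3 training/inference paths in plain Python.
# ASSUMPTIONS (not from the paper): EDM-style noising z_tau = z0 + sigma*eps,
# a denoiser that predicts the clean latent z0, and a deterministic Euler
# sampler. encode() and oracle_denoiser() are hypothetical stand-ins.

Vec = list[float]


def encode(frame: Vec) -> Vec:
    """Hypothetical encoder: 2x compression by averaging adjacent pixels."""
    return [(frame[i] + frame[i + 1]) / 2 for i in range(0, len(frame), 2)]


def oracle_denoiser(z_noisy: Vec, sigma: float, cond) -> Vec:
    """Hypothetical 'perfect' denoiser for toy deterministic dynamics: the
    next latent is the last context latent shifted by 0.5 * action. It
    ignores z_noisy and sigma, which a learned network would not."""
    latents, action = cond
    return [v + 0.5 * action for v in latents[-1]]


def training_loss(denoiser, history: list[Vec], action: int,
                  next_frame: Vec, sigma: float, rng: random.Random) -> float:
    # (a) Context latents z_{t-3..t}^0. In an autodiff framework gradients
    # flow through these, so the encoder is trained end to end.
    cond = ([encode(x) for x in history], action)
    # Target latent from x_{t+1}: same encoder, but it would be detached
    # (stop-gradient); plain Python has no autograd, so this is implicit.
    target = encode(next_frame)
    # Noise the target at level sigma and ask the denoiser to recover it.
    noisy = [z + sigma * rng.gauss(0.0, 1.0) for z in target]
    pred = denoiser(noisy, sigma, cond)
    return sum((p - z) ** 2 for p, z in zip(pred, target)) / len(target)


def sample_next_latent(denoiser, history: list[Vec], action: int,
                       sigmas: list[float], rng: random.Random) -> Vec:
    # (b) Start from pure noise at sigmas[0] and take Euler steps along
    # dz/dsigma = (z - D(z, sigma, cond)) / sigma down to sigmas[-1] = 0.
    cond = ([encode(x) for x in history], action)
    z = [sigmas[0] * rng.gauss(0.0, 1.0) for _ in cond[0][-1]]
    for s, s_next in zip(sigmas[:-1], sigmas[1:]):
        d = denoiser(z, s, cond)
        z = [zi + (s_next - s) * (zi - di) / s for zi, di in zip(z, d)]
    return z


# Four context frames x_{t-3..t} (4 "pixels" each) and the true x_{t+1}.
history = [[0.0] * 4, [0.0] * 4, [0.0] * 4, [0.0, 2.0, 4.0, 4.0]]
next_frame = [1.0, 3.0, 5.0, 5.0]
rng = random.Random(0)

loss = training_loss(oracle_denoiser, history, 2, next_frame, 1.0, rng)
z_next = sample_next_latent(oracle_denoiser, history, 2,
                            [5.0, 2.0, 0.5, 0.0], rng)
print(loss, z_next)
```

With the oracle denoiser the training loss is exactly 0, and the sampler lands on the oracle's next latent [2.0, 5.0] up to floating-point rounding: each Euler step scales the distance to the denoiser's prediction by $\sigma_{next}/\sigma$, and the final step has $\sigma_{next} = 0$. A learned denoiser would instead make different predictions at each noise level, which is why several denoising steps are used.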