Accurate and Efficient World Modeling with Masked Latent Transformers

Maxime Burchi, Radu Timofte

TL;DR

The paper addresses the challenge of achieving both accuracy and efficiency in world modeling for model-based RL in complex environments. It introduces EMERALD, a masked latent Transformer world model that uses a spatial latent state and a MaskGIT predictor to generate accurate trajectories in latent space, enabling imagination-driven actor-critic learning. Empirical results on Crafter show state-of-the-art performance, surpassing human experts within 10M environment steps and unlocking all 22 achievements, while also demonstrating improved training efficiency over pixel-based approaches. The work highlights the advantages of combining spatial latents, Transformer memory, and masked latent decoding to preserve important perceptual details and long-term memory, with potential applicability to broader domains beyond Crafter.

Abstract

The Dreamer algorithm has recently obtained remarkable performance across diverse environment domains by training powerful agents with simulated trajectories. However, the compressed nature of its world model's latent space can result in the loss of crucial information, negatively affecting the agent's performance. Recent approaches, such as $\Delta$-IRIS and DIAMOND, address this limitation by training more accurate world models. However, these methods require training agents directly from pixels, which reduces training efficiency and prevents the agent from benefiting from the inner representations learned by the world model. In this work, we propose an alternative approach to world modeling that is both accurate and efficient. We introduce EMERALD (Efficient MaskEd latent tRAnsformer worLD model), a world model using a spatial latent state with MaskGIT predictions to generate accurate trajectories in latent space and improve agent performance. On the Crafter benchmark, EMERALD achieves new state-of-the-art performance, becoming the first method to surpass human expert performance within 10M environment steps. Our method also succeeds in unlocking all 22 Crafter achievements at least once during evaluation.

Paper Structure

This paper contains 30 sections, 7 equations, 11 figures, 11 tables.

Figures (11)

  • Figure 1: Achievements score and collected Frames Per Second (FPS) of recently published model-based methods on the Crafter benchmark. EMERALD is the first method to exceed the performance of human experts on the benchmark. The world model uses a spatial latent state with MaskGIT predictions to improve performance and training efficiency. $\Delta$-IRIS proposed an accurate world model but suffers from lower training efficiency due to autoregressive decoding in latent space and agent learning from reconstructed images.
  • Figure 2: Comparison of EMERALD image reconstruction with DreamerV3. We show the 5 frames with highest reconstruction error among a batch of 1024 test observations. The top row indicates original images, the middle row shows reconstructed images and the bottom row shows reconstruction error. We observe that EMERALD achieves near-perfect reconstruction, with errors resulting mostly from player orientation and textures. DreamerV3 fails to perceive crucial details like diamonds and skeleton arrows, which diminishes the agent's perception capacity and negatively impacts its performance.
  • Figure 3: Efficient masked latent Transformer-based world model. EMERALD uses a spatial latent state $z_{t}$ and a temporal hidden state $h_{t}$ to model the environment accurately and effectively. The world model predictions are made using a spatial MaskGIT predictor network to increase decoding speed while maintaining accuracy. Actor-critic learning is performed by imagining trajectories in latent space which allows the agents to benefit from world model inner representations.
  • Figure 4: Comparison of scheduled parallel decoding in EMERALD vs. autoregressive sequential decoding used by IRIS and $\Delta$-IRIS. We illustrate the $H \times W \times G$ latent space of our method with only 4 of the 32 groups for better clarity. Sequential decoding predicts one token at a time, significantly impacting efficiency. In contrast, EMERALD uses parallel predictions with scheduled refinements, reducing decoding time while preserving the coherence of predicted tokens.
  • Figure 5: Cosine mask schedule. During training, we uniformly sample decoding times $\tau$ between 0 and 1. During imagination, the world model samples all masked tokens and refines $N=\lfloor \gamma HWG \rfloor$ tokens with lower probability. We illustrate the schedule used for $S=3$ decoding steps.
  • ...and 6 more figures
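
The cosine mask schedule described in the Figure 5 caption can be illustrated with a short Python sketch. It assumes the MaskGIT-style form $\gamma(\tau) = \cos(\pi \tau / 2)$ and evenly spaced decoding times $\tau = s/S$; neither is stated in the captions above. The function names and the $4 \times 4$ spatial size are hypothetical, and only $G = 32$ groups and $S = 3$ steps come from the captions.

```python
import math


def cosine_mask_ratio(tau: float) -> float:
    """Assumed MaskGIT-style cosine schedule: gamma(tau) = cos(pi * tau / 2)."""
    return math.cos(math.pi * tau / 2)


def tokens_to_refine(num_steps: int, height: int, width: int, groups: int) -> list[int]:
    """Return N = floor(gamma * H * W * G) for each decoding step s = 1..S.

    Assumption: decoding times are evenly spaced, tau = s / S.
    """
    total_tokens = height * width * groups
    counts = []
    for step in range(1, num_steps + 1):
        tau = step / num_steps
        gamma = cosine_mask_ratio(tau)
        counts.append(math.floor(gamma * total_tokens))
    return counts


if __name__ == "__main__":
    # Hypothetical 4 x 4 spatial grid; G = 32 and S = 3 follow the captions.
    for step, n in enumerate(tokens_to_refine(3, 4, 4, 32), start=1):
        print(f"step {step}: refine {n} lowest-probability tokens")
```

Under these assumptions, the number of tokens left for refinement shrinks at every step and reaches zero at the final step, so decoding ends with every token fixed.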