Masked Generative Priors Improve World Models Sequence Modelling Capabilities

Cristian Meo, Mircea Lica, Zarif Ikram, Akihiro Nakano, Vedant Shah, Aniket Rajiv Didolkar, Dianbo Liu, Anirudh Goyal, Justin Dauwels

TL;DR

This work tackles data efficiency in deep reinforcement learning by equipping a Transformer-based world model with a Masked Generative Prior (MaskGIT). GIT-STORM replaces the MLP prior in STORM with a MaskGIT prior and adds a state mixer to handle continuous actions. These changes improve sequence modelling, which translates into the highest human-normalized mean and IQM among the compared methods on Atari 100k and strong results on continuous control tasks in the DeepMind Control Suite. Imagined trajectories are also of higher quality, as measured by video-prediction metrics (FVD and perplexity), and the state mixer extends Transformer-based world models to continuous action domains. Overall, masked generative priors offer a versatile inductive bias for more accurate world models and more effective RL policies across discrete and continuous tasks.
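A masked generative prior predicts all tokens of the next latent state in a few parallel rounds instead of one token at a time. The sketch below illustrates this with MaskGIT-style iterative decoding. It assumes the cosine masking schedule from the original MaskGIT work and uses greedy, confidence-ranked unmasking with no sampling temperature, so it is a simplification. `MASK`, `predict`, and `make_stub_predictor` are hypothetical names. The stub returns random predictions in place of GIT-STORM's bidirectional transformer conditioned on the autoregressive hidden state. The 32-token, 32-class latent size is also an assumption, borrowed from STORM's categorical latent.

```python
import math
import random

MASK = -1  # hypothetical sentinel id marking a not-yet-decoded token


def cosine_mask_ratio(r: float) -> float:
    """Cosine schedule: fraction of tokens still masked at decoding progress r in [0, 1]."""
    return math.cos(math.pi * r / 2)


def iterative_decode(predict, num_tokens: int, steps: int) -> list:
    """Decode a fully masked token sequence in `steps` parallel rounds.

    `predict(tokens)` returns one (token_id, confidence) pair per position.
    In a MaskGIT prior this role is played by a bidirectional transformer
    conditioned on the world model's hidden state; here it is a stand-in.
    """
    tokens = [MASK] * num_tokens
    for step in range(1, steps + 1):
        preds = predict(tokens)
        masked = [i for i, t in enumerate(tokens) if t == MASK]
        # How many tokens should still be masked once this round is done.
        keep_masked = math.floor(cosine_mask_ratio(step / steps) * num_tokens)
        num_to_reveal = len(masked) - keep_masked
        # Commit the most confident predictions among the masked positions;
        # already-decoded tokens are never changed.
        masked.sort(key=lambda i: preds[i][1], reverse=True)
        for i in masked[:num_to_reveal]:
            tokens[i] = preds[i][0]
    return tokens


def make_stub_predictor(vocab_size: int, seed: int = 0):
    """Hypothetical stand-in for the bidirectional transformer: random ids and confidences."""
    rng = random.Random(seed)

    def predict(tokens):
        return [(rng.randrange(vocab_size), rng.random()) for _ in tokens]

    return predict


# Assumed latent layout: 32 categorical tokens with 32 classes each.
z_next = iterative_decode(make_stub_predictor(vocab_size=32), num_tokens=32, steps=8)
print(z_next)
```

The cosine schedule reveals few tokens in the early rounds, when predictions are least informed, and many near the end. The whole latent is therefore produced in `steps` forward passes rather than one pass per token.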

Abstract

Deep Reinforcement Learning (RL) has become the leading approach for creating artificial agents in complex environments. Model-based approaches, which equip RL agents with world models that predict environment dynamics, are among the most promising directions for improving data efficiency, a critical step toward bridging the gap between research and real-world deployment. In particular, world models enhance sample efficiency by learning in imagination, which involves training a generative sequence model of the environment in a self-supervised manner. Recently, Masked Generative Modelling has emerged as an efficient and effective inductive bias for modelling and generating token sequences. Building on the Efficient Stochastic Transformer-based World Models (STORM) architecture, we replace the traditional MLP prior with a Masked Generative Prior (specifically, a MaskGIT prior) and introduce GIT-STORM. We evaluate our model on two downstream tasks: reinforcement learning and video prediction. GIT-STORM demonstrates substantial performance gains in RL tasks on the Atari 100k benchmark. Moreover, we apply Transformer-based World Models to continuous action environments for the first time, addressing a significant gap in prior research. To achieve this, we employ a state mixer function that integrates latent state representations with actions, enabling our model to handle continuous control tasks. We validate this approach through qualitative and quantitative analyses on the DeepMind Control Suite, showcasing the effectiveness of Transformer-based World Models in this new domain. Our results highlight the versatility and efficacy of the MaskGIT dynamics prior, paving the way for more accurate world models and effective RL policies.
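The abstract does not spell out the state mixer's architecture, so the following is a minimal sketch under an assumed design. The flattened stochastic latent and the raw continuous action vector are concatenated and passed through one linear layer with a SiLU activation, giving one embedding per timestep for the autoregressive transformer. `StateMixer`, `silu`, the layer sizes, and the initialization are all hypothetical. The sketch only shows that real-valued actions can be mixed in directly, with no one-hot encoding over a discrete action set.

```python
import math
import random


def silu(x: float) -> float:
    """Numerically stable SiLU (x * sigmoid(x))."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


class StateMixer:
    """Hypothetical state mixer: fuses a flattened latent z_t with a continuous action a_t.

    Assumed design: concatenate [z_t, a_t], apply one linear layer, then SiLU.
    The output would be the per-timestep input embedding of the sequence model.
    """

    def __init__(self, latent_dim: int, action_dim: int, embed_dim: int, seed: int = 0):
        rng = random.Random(seed)
        in_dim = latent_dim + action_dim
        bound = 1.0 / math.sqrt(in_dim)  # uniform fan-in initialization
        self.weight = [[rng.uniform(-bound, bound) for _ in range(in_dim)] for _ in range(embed_dim)]
        self.bias = [0.0] * embed_dim
        self.latent_dim = latent_dim
        self.action_dim = action_dim

    def __call__(self, z, a):
        if len(z) != self.latent_dim or len(a) != self.action_dim:
            raise ValueError("latent or action has the wrong size")
        # Real-valued actions are appended as-is; no one-hot encoding is needed.
        x = list(z) + list(a)
        return [silu(sum(w * v for w, v in zip(row, x)) + b) for row, b in zip(self.weight, self.bias)]


# Assumed sizes: 32 categoricals x 32 classes flattened to 1024, and a 6-D action.
z_t = [0.0] * 1024
for var in range(32):
    z_t[var * 32 + (7 * var) % 32] = 1.0  # one active class per categorical (arbitrary choice)
a_t = [0.1, -0.5, 0.3, 0.0, 0.9, -1.0]  # e.g. normalized joint commands in [-1, 1]

mixer = StateMixer(latent_dim=1024, action_dim=6, embed_dim=16)
e_t = mixer(z_t, a_t)
print(len(e_t))  # 16
```

Concatenating before the projection lets the layer learn arbitrary linear interactions between latent and action features. A deeper MLP or added normalization layers would be natural variations of this assumed design.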

Paper Structure

This paper contains 45 sections, 11 equations, 17 figures, 15 tables, and 2 algorithms.

Figures (17)

  • Figure 1: Overview of our proposed GIT-STORM method. (Left) The MaskGIT prior introduced to model the dynamics of the environment. The bidirectional transformer devlin2018bert combines the hidden state given by the autoregressive transformer and the masked posterior $z_t \circ m_t$ to produce the prior corresponding to the next timestep. (Right) MLP prior originally used in STORM.
  • Figure 2: (Left) Human-normalized mean across the Atari 100k benchmark. GIT-STORM outperforms all other baselines. (Middle) Human-normalized median. TWM achieves the highest median value of $51\%$. (Right) Interquartile mean (IQM). GIT-STORM outperforms all other baselines.
  • Figure 3: Probability of Improvement of GIT-STORM over the baselines on the Atari 100k benchmark (Left) and the DMC benchmark (Right), i.e., how likely GIT-STORM is to outperform each baseline.
  • Figure 4: Comparison of human-normalized mean (left) and median (right) on the DMC benchmark.
  • Figure 5: GIT-STORM end-to-end pipeline. As in STORM zhang2023storm, GIT-STORM performs sequence modelling with an autoregressive transformer that predicts future stochastic latents $z_t$, rewards $r_t$, and termination signals $c_t$. In contrast with STORM, GIT-STORM uses a Masked Generative Prior to model the dynamics of the environment.
  • ...and 12 more figures
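Figures 2 to 4 report human-normalized scores, their interquartile mean (IQM), and the Probability of Improvement. The sketch below shows how these statistics can be computed, assuming the standard definitions popularized by the rliable library; it is not code from this paper.

- IQM averages the middle 50% of all run-level scores.
- Probability of Improvement averages, over tasks, the chance that a random run of algorithm X beats a random run of algorithm Y, counting ties as one half.

The function names are hypothetical and bootstrap confidence intervals are omitted. The toy scores are invented for illustration and are not results from the paper.

```python
from typing import Sequence


def human_normalized(score: float, random_score: float, human_score: float) -> float:
    """0.0 matches a random policy, 1.0 matches the human reference score."""
    return (score - random_score) / (human_score - random_score)


def iqm(scores: Sequence[float]) -> float:
    """Interquartile mean: drop the lowest and highest 25% of scores, average the rest."""
    ordered = sorted(scores)
    cut = int(0.25 * len(ordered))
    middle = ordered[cut:len(ordered) - cut]
    return sum(middle) / len(middle)


def probability_of_improvement(x_runs, y_runs) -> float:
    """Average over tasks of P(X > Y), with ties counted as 1/2.

    x_runs[m] and y_runs[m] are lists of per-run scores of algorithms X and Y on task m.
    """
    per_task = []
    for xs, ys in zip(x_runs, y_runs):
        wins = sum(1.0 if x > y else 0.5 if x == y else 0.0 for x in xs for y in ys)
        per_task.append(wins / (len(xs) * len(ys)))
    return sum(per_task) / len(per_task)


# Toy numbers for illustration only; they are not results from the paper.
print(human_normalized(500.0, random_score=100.0, human_score=1100.0))  # 0.4
print(iqm([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.9, 3.0]))
x = [[3.0, 4.0], [1.0, 1.0]]  # algorithm X: 2 tasks x 2 runs
y = [[2.0, 4.0], [1.0, 2.0]]  # algorithm Y: same tasks
print(probability_of_improvement(x, y))  # 0.4375
```

In the IQM example, the outlier score of 3.0 is trimmed away. This robustness to outlier runs is why IQM is often reported alongside, or instead of, the mean on few-run benchmarks such as Atari 100k, while it remains less noisy than the median.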