Learning Transformer-based World Models with Contrastive Predictive Coding

Maxime Burchi, Radu Timofte

TL;DR

TWISTER tackles the limitations of Transformer-based world models by introducing action-conditioned CPC to learn long-horizon temporal representations. It combines a Transformer State-Space Model with discrete latent encodings and an AC-CPC objective to produce predictive, high-level features used by latent-space actor-critic learning. On Atari 100k, TWISTER achieves a human-normalized mean of 162% and a median of 77%, setting a new record among non-lookahead model-based methods, with DeepMind Control Suite results also demonstrating strong performance. Ablation studies highlight the importance of CPC horizon, action conditioning, and data augmentation, and show that Transformer architectures better exploit AC-CPC than RSSM-based variants. Overall, the work underscores the value of horizon-aware self-supervised representations for efficient planning in model-based reinforcement learning.

Abstract

The DreamerV3 algorithm recently obtained remarkable performance across diverse environment domains by learning an accurate world model based on Recurrent Neural Networks (RNNs). Following the success of model-based reinforcement learning algorithms and the rapid adoption of the Transformer architecture for its superior training efficiency and favorable scaling properties, recent works such as STORM have proposed replacing RNN-based world models with Transformer-based world models using masked self-attention. However, despite their improved training efficiency, these methods have had limited impact on performance compared to the Dreamer algorithm and struggle to learn competitive Transformer-based world models. In this work, we show that the next-state prediction objective adopted in previous approaches is insufficient to fully exploit the representation capabilities of Transformers. We propose to extend world model predictions to longer time horizons by introducing TWISTER (Transformer-based World model wIth contraSTivE Representations), a world model that uses action-conditioned Contrastive Predictive Coding to learn high-level temporal feature representations and improve agent performance. TWISTER achieves a human-normalized mean score of 162% on the Atari 100k benchmark, setting a new record among state-of-the-art methods that do not employ look-ahead search.
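The abstract names action-conditioned Contrastive Predictive Coding as the core objective. CPC objectives are commonly trained with an InfoNCE-style loss, in which a prediction must pick out its matching future state among the other samples in the batch. The following is a minimal sketch of that idea, not the authors' implementation. The names `predict_future` and `info_nce`, the dot-product scoring, and the toy additive predictor are all assumptions made for illustration.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def predict_future(state, actions):
    # Hypothetical stand-in for a learned prediction head: the predicted
    # future embedding is the current state shifted by each action vector.
    # A real model would use a trained network here.
    out = list(state)
    for action in actions:
        out = [o + a for o, a in zip(out, action)]
    return out


def info_nce(predictions, targets):
    # predictions[i] should match targets[i] (the positive); every other
    # targets[j] in the batch acts as a negative sample.
    losses = []
    for i, pred in enumerate(predictions):
        logits = [dot(pred, tgt) for tgt in targets]
        m = max(logits)  # subtract the max for numerical stability
        log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
        losses.append(log_norm - logits[i])  # -log softmax at the positive
    return sum(losses) / len(losses)


# Toy batch of two samples with a two-step action sequence each.
states = [[1.0, 0.0], [0.0, 1.0]]
actions = [[[1.0, 0.0], [2.0, 0.0]], [[0.0, 1.0], [0.0, 2.0]]]
targets = [[4.0, 0.0], [0.0, 4.0]]  # assumed future-state embeddings

predictions = [predict_future(s, a) for s, a in zip(states, actions)]
loss = info_nce(predictions, targets)
```

In this toy batch the action-conditioned predictions coincide with their targets, so the loss is close to zero. Predictions that carry no information (all zeros) give a loss of log 2, the chance level for a batch of two.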

Paper Structure

This paper contains 30 sections, 6 equations, 12 figures, 11 tables.

Figures (12)

  • Figure 1: Human-normalized mean and median scores of recently published model-based methods on the Atari 100k benchmark. TWISTER outperforms other model-based approaches. TWM, IRIS, STORM, and $\Delta$-IRIS employ a Transformer-based world model, while DreamerV3 uses an RNN-based model.
  • Figure 2: Cosine Similarities between TWISTER latent state $z_{t}$ and future states $z_{t+k}$ aggregated over all 26 games of the Atari 100k benchmark. We show average similarities over 5 seeds.
  • Figure 3: Transformer-based world model with contrastive representations. The world model learns temporal feature representations by maximizing the mutual information between model states $s_{t}$ and future stochastic states $z'_{t:t+K}$ obtained from augmented views of image observations. The encoder network converts image observations into stochastic states $z_{t}$, from which a decoder network learns to reconstruct images while the masked attention Transformer network predicts next episode continuations, rewards and stochastic states conditioned on selected actions.
  • Figure 4: AC-CPC predictions made by the world model. We show the target positive sample without augmentation and the predicted most/least similar samples among the batch of augmented image views. We observe that TWISTER learns to identify the samples most/least similar to the future target state using observation details such as the ball position, game score, or agent movements. AC-CPC requires the agent to focus on observation details to accurately predict future samples, thereby preventing common failure cases where small objects are ignored by the reconstruction loss.
  • Figure 5: Mean and median scores, computed with stratified bootstrap confidence intervals (Agarwal et al., 2021). TWISTER achieves a normalized mean of 1.62 and a median of 0.77.
  • ...and 7 more figures
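Figures 1 and 5 report human-normalized mean and median scores. The standard normalization maps a random policy to 0 and the human reference to 1, and the aggregate is then taken across games. The sketch below illustrates the arithmetic only. The game names and raw scores are hypothetical placeholders, not values from the paper.

```python
from statistics import mean, median


def human_normalized(agent_score, random_score, human_score):
    # 0.0 corresponds to a random policy, 1.0 to the human reference.
    return (agent_score - random_score) / (human_score - random_score)


# Hypothetical (agent, random, human) raw scores, for illustration only.
scores = {
    "game_a": (50.0, 0.0, 100.0),
    "game_b": (300.0, 100.0, 200.0),
    "game_c": (10.0, 10.0, 20.0),
}

normalized = {game: human_normalized(*s) for game, s in scores.items()}
hn_mean = mean(normalized.values())
hn_median = median(normalized.values())
```

Here one strongly superhuman game pulls the mean above the median, which is one reason a benchmark can show a mean well above its median, as with the 1.62 and 0.77 reported in Figure 5.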