
Learning to Play Atari in a World of Tokens

Pranav Agarwal, Sheldon Andrews, Samira Ebrahimi Kahou

TL;DR

This paper tackles the sample inefficiency of model-based RL by introducing DART, a transformer-based framework that operates on discrete latent representations. It uses a transformer-decoder to model world dynamics from tokenized observations and a transformer-encoder to learn the policy, augmented with a memory token to address partial observability. Among methods that do not use look-ahead search, DART achieves state-of-the-art performance on Atari 100k, with a median human-normalized score of 0.790 and superhuman results in 9 of 26 games. Attention analyses and targeted ablations make its behavior easier to interpret. The work suggests that discrete representations and memory-augmented transformers can improve sample efficiency and yield clearer insight into agent behavior, and it outlines extensions to look-ahead planning and continuous-action spaces for broader applicability.
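
The headline result is a median human-normalized score (HNS). The sketch below shows how per-game HNS values are commonly computed and aggregated into the mean, median, and interquartile mean (IQM) statistics of the kind compared in Figure 2. It assumes the standard Atari definition HNS = (agent - random) / (human - random). The helper names and the toy scores are hypothetical, not from the paper. The IQM here trims per-game scores, whereas evaluation libraries typically pool all runs and games before trimming.

```python
from statistics import mean, median


# Hypothetical helper; the formula is the standard Atari human-normalized score.
def human_normalized_score(agent: float, rand: float, human: float) -> float:
    """0.0 means random-policy level, 1.0 means human level."""
    return (agent - rand) / (human - rand)


def interquartile_mean(values: list[float]) -> float:
    """Mean of the middle 50%: drop the lowest and highest 25% of sorted values."""
    xs = sorted(values)
    cut = len(xs) // 4  # number of values trimmed from each end
    return mean(xs[cut:len(xs) - cut])


def summarize(scores: dict[str, tuple[float, float, float]]) -> dict[str, float]:
    """scores maps game -> (agent, random, human) raw game scores."""
    hns = [human_normalized_score(a, r, h) for a, r, h in scores.values()]
    return {
        "mean": mean(hns),
        "median": median(hns),
        "iqm": interquartile_mean(hns),
        # "Beats humans" is taken here to mean HNS strictly above 1.0.
        "superhuman": sum(1 for s in hns if s > 1.0),
    }


# Toy, made-up scores (not results from the paper).
toy_scores = {
    "game_a": (2.0, 0.0, 10.0),
    "game_b": (15.0, 5.0, 25.0),
    "game_c": (5.0, 5.0, 25.0),
    "game_d": (17.0, 2.0, 12.0),
    "game_e": (41.0, 1.0, 11.0),
}
stats = summarize(toy_scores)
print(stats)
```

The median and IQM are preferred over the mean on Atari 100k because a few games with very large HNS can dominate the mean, as game_e does in this toy example.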

Abstract

Model-based reinforcement learning agents utilizing transformers have shown improved sample efficiency due to their ability to model extended context, resulting in more accurate world models. However, for complex reasoning and planning tasks, these methods primarily rely on continuous representations. This complicates modeling of discrete properties of the real world such as disjoint object classes between which interpolation is not plausible. In this work, we introduce discrete abstract representations for transformer-based learning (DART), a sample-efficient method utilizing discrete representations for modeling both the world and learning behavior. We incorporate a transformer-decoder for auto-regressive world modeling and a transformer-encoder for learning behavior by attending to task-relevant cues in the discrete representation of the world model. For handling partial observability, we aggregate information from past time steps as memory tokens. DART outperforms previous state-of-the-art methods that do not use look-ahead search on the Atari 100k sample efficiency benchmark with a median human-normalized score of 0.790 and beats humans in 9 out of 26 games. We release our code at https://pranaval.github.io/DART/.
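
DART's discrete representation comes from a VQ-VAE (see Figure 1). The core step of vector quantization maps each continuous encoder output to the index of its nearest entry in a codebook, and that index is the discrete token. The sketch below shows only this lookup, under stated assumptions. The encoder, decoder, and codebook training (for example straight-through gradients and commitment losses) are omitted. The codebook size, embedding dimension, and all names are illustrative and hypothetical, not the paper's settings.

```python
Vector = list[float]


def squared_distance(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(latents: list[Vector], codebook: list[Vector]) -> list[int]:
    """Map each continuous latent vector to the index of its nearest codebook entry."""
    return [
        min(range(len(codebook)), key=lambda k: squared_distance(z, codebook[k]))
        for z in latents
    ]


def dequantize(tokens: list[int], codebook: list[Vector]) -> list[Vector]:
    """Look the discrete tokens back up in the codebook (input to the VQ-VAE decoder)."""
    return [codebook[k] for k in tokens]


# Toy 4-entry codebook of 2-D embeddings (real codebooks are learned and larger).
codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# Stand-in for encoder outputs of one observation x_t (one vector per spatial position).
latents = [[0.1, 0.2], [0.9, 0.1], [0.4, 0.8], [0.7, 0.9]]
tokens = quantize(latents, codebook)  # discrete tokens z_t -> [0, 1, 2, 3]
print(tokens)
```

Because every observation becomes a short sequence of integers, the world model can treat next-observation prediction as classification over codebook indices rather than regression in a continuous space.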

Paper Structure

This paper contains 19 sections, 3 equations, 4 figures, 9 tables, and 2 algorithms.

Figures (4)

  • Figure 1: Discrete abstract representation for transformer-based learning (DART): The original observation $x_t$ is encoded into discrete tokens $z_t$ using a VQ-VAE. These tokenized observations, together with the predicted action, serve as inputs to the world model, which is a transformer-decoder network. The predicted tokens, along with a CLS token and a MEM token, are used as input by the policy, which is a transformer-encoder network. The CLS token aggregates information from the observation tokens and the MEM token to learn a common representation, which is then used for action and value predictions. This common representation also models memory: it serves as the MEM token at the next time step.
  • Figure 2: Comparison of Mean, Median, and Interquartile Mean Human-Normalized Scores
  • Figure 3: Comparison of different models using performance profiles and probabilities of improvement.
  • Figure 4: Comparison of Memory Requirements Across Atari Games: Atari games have different memory requirements depending on their dynamics. In games with relatively static or slow-moving objects, such as Amidar, each observation already contains nearly complete information, so the policy draws less information from the memory token. In contrast, games with rapidly changing environments, such as Breakout, Krull, and PrivateEye, require modeling the past trajectories of objects. As a result, the policy for these games relies heavily on the memory token to carry information from past states forward.
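
Figures 1 and 4 describe a recurrence: the CLS token attends over [CLS, MEM, observation tokens], its output feeds the action and value heads, and that same output becomes the MEM token at the next step. The attention weight on MEM indicates how much a game relies on memory. The toy sketch below illustrates only this recurrence. It uses a single head, a single layer, and identity query/key/value projections, while the paper uses a full transformer-encoder with learned parameters. The action and value heads are omitted, and the zero initial MEM vector and all function names are hypothetical.

```python
import math

Vector = list[float]


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query: Vector, tokens: list[Vector]) -> tuple[Vector, list[float]]:
    """Single-head scaled dot-product attention with identity Q/K/V projections."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, tok)) / math.sqrt(d) for tok in tokens]
    weights = softmax(scores)
    out = [sum(w * tok[i] for w, tok in zip(weights, tokens)) for i in range(d)]
    return out, weights


def policy_step(cls: Vector, mem: Vector, obs_tokens: list[Vector]) -> tuple[Vector, float]:
    """CLS attends over [CLS, MEM, z_1..z_K]; its output is the shared representation."""
    rep, weights = attend(cls, [cls, mem] + obs_tokens)
    # In the full model, action and value heads would read from `rep`.
    return rep, weights[1]  # weights[1] is the attention paid to the MEM token


def rollout(cls: Vector, mem0: Vector, trajectory: list[list[Vector]]):
    mem, mem_weights = mem0, []
    for obs_tokens in trajectory:
        rep, w_mem = policy_step(cls, mem, obs_tokens)
        mem_weights.append(w_mem)
        mem = rep  # this step's CLS output becomes the next step's MEM token
    return mem, mem_weights


cls = [1.0, 0.0]
trajectory = [[[0.5, 0.5], [0.0, 1.0]], [[1.0, 0.0], [0.2, 0.3]]]
final_mem, mem_weights = rollout(cls, [0.0, 0.0], trajectory)
print(final_mem, mem_weights)
```

Averaging the per-step MEM weights over an episode gives a rough per-game measure of memory reliance, similar in spirit to the comparison in Figure 4.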
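
Figure 3 compares models with performance profiles and probabilities of improvement. The sketch below assumes the definitions popularized by the rliable evaluation library. For each game, the probability that algorithm X improves on Y compares every run of X with every run of Y, counting ties as 0.5, and the per-game values are averaged over games. A performance profile reports the fraction of all (game, run) normalized scores strictly above each threshold tau. Bootstrap confidence intervals are omitted, and the function names and toy scores are hypothetical.

```python
def pairwise_improvement(x_runs: list[float], y_runs: list[float]) -> float:
    """P(X > Y) on one game: compare every run of X with every run of Y; ties count 0.5."""
    total = 0.0
    for x in x_runs:
        for y in y_runs:
            total += 1.0 if x > y else (0.5 if x == y else 0.0)
    return total / (len(x_runs) * len(y_runs))


def probability_of_improvement(
    x_scores: dict[str, list[float]], y_scores: dict[str, list[float]]
) -> float:
    """Average of per-game P(X > Y); assumes both dicts cover the same games."""
    games = list(x_scores)
    return sum(pairwise_improvement(x_scores[g], y_scores[g]) for g in games) / len(games)


def performance_profile(scores: dict[str, list[float]], taus: list[float]) -> list[float]:
    """Fraction of all (game, run) normalized scores strictly above each threshold tau."""
    flat = [s for runs in scores.values() for s in runs]
    return [sum(s > tau for s in flat) / len(flat) for tau in taus]


# Toy HNS values per game and per seed (made up, not from the paper).
x_scores = {"game_a": [1.0, 2.0], "game_b": [0.5, 0.5]}
y_scores = {"game_a": [1.5, 1.5], "game_b": [0.5, 0.2]}
print(probability_of_improvement(x_scores, y_scores))  # 0.625
print(performance_profile(x_scores, [0.0, 1.0]))  # [1.0, 0.25]
```

A probability of improvement above 0.5 means X tends to beat Y on a randomly chosen game and run pairing, which is more robust to a few outlier games than comparing mean scores.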