Learning to Play Atari in a World of Tokens
Pranav Agarwal, Sheldon Andrews, Samira Ebrahimi Kahou
TL;DR
This paper addresses the sample inefficiency of reinforcement learning with DART, a model-based, transformer-based framework that operates on discrete latent representations. DART uses a transformer-decoder to model world dynamics from tokenized observations and a transformer-encoder to learn the policy, which is augmented with a memory token to handle partial observability. On the Atari 100k benchmark, DART achieves state-of-the-art performance among methods that do not use look-ahead search, with a median human-normalized score of 0.790 and superhuman results in 9 of 26 games. Attention analyses and targeted ablations help explain the learned behavior. The work suggests that discrete representations and memory-augmented transformers can improve sample efficiency while making behavior easier to interpret, and it outlines extensions to look-ahead planning and continuous action spaces.
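The headline numbers in the TL;DR are aggregates of per-game human-normalized scores (HNS). The sketch below is a hedged illustration that uses the standard Atari definition, HNS = (agent - random) / (human - random), where an HNS above 1 means the agent beats the human reference. The function names and the three game scores are hypothetical and are not DART's results.

```python
from statistics import median

def human_normalized_score(agent: float, random_score: float, human: float) -> float:
    """Standard Atari HNS: 0 matches a random policy, 1 matches the human reference."""
    return (agent - random_score) / (human - random_score)

def summarize(scores: dict[str, tuple[float, float, float]]) -> tuple[float, int]:
    """Return (median HNS over games, number of games with HNS > 1, i.e. superhuman).

    Each value is a hypothetical (agent, random, human) raw-score triple.
    """
    hns = [human_normalized_score(a, r, h) for a, r, h in scores.values()]
    return median(hns), sum(1 for s in hns if s > 1.0)

# Hypothetical raw scores for three made-up games (not real benchmark data).
example = {
    "GameA": (250.0, 50.0, 150.0),  # HNS = 2.0
    "GameB": (100.0, 0.0, 200.0),   # HNS = 0.5
    "GameC": (30.0, 10.0, 110.0),   # HNS = 0.2
}
print(summarize(example))  # (0.5, 1)
```

Reporting the median rather than the mean keeps a few games with very large HNS from dominating the aggregate.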
Abstract
Model-based reinforcement learning agents utilizing transformers have shown improved sample efficiency due to their ability to model extended context, resulting in more accurate world models. However, for complex reasoning and planning tasks, these methods primarily rely on continuous representations. This complicates the modeling of discrete properties of the real world, such as disjoint object classes between which interpolation is not plausible. In this work, we introduce discrete abstract representations for transformer-based learning (DART), a sample-efficient method utilizing discrete representations for modeling both the world and learning behavior. We incorporate a transformer-decoder for autoregressive world modeling and a transformer-encoder for learning behavior by attending to task-relevant cues in the discrete representation of the world model. To handle partial observability, we aggregate information from past time steps into memory tokens. DART outperforms previous state-of-the-art methods that do not use look-ahead search on the Atari 100k sample efficiency benchmark, with a median human-normalized score of 0.790, and beats humans in 9 out of 26 games. We release our code at https://pranaval.github.io/DART/.
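To make the token flow concrete, here is a minimal pure-Python sketch of two steps the abstract names: turning continuous features into discrete tokens, and reserving a memory slot in the policy's input sequence. It assumes a VQ-style nearest-neighbor codebook lookup for discretization. The codebook values, `quantize`, `build_policy_input`, and `MEMORY_TOKEN` are hypothetical, and the transformer-decoder world model and transformer-encoder policy are omitted entirely.

```python
from math import dist

# Hypothetical codebook type: each entry is a small "embedding" vector.
Codebook = list[tuple[float, ...]]

def quantize(features: list[tuple[float, ...]], codebook: Codebook) -> list[int]:
    """Map each continuous feature vector to the index of its nearest codebook entry
    (Euclidean distance), assuming a VQ-style discretization."""
    return [min(range(len(codebook)), key=lambda k: dist(f, codebook[k])) for f in features]

MEMORY_TOKEN = -1  # hypothetical placeholder id marking the memory slot's position

def build_policy_input(obs_tokens: list[int]) -> list[int]:
    """Prepend a memory slot to one time step's discrete observation tokens.
    Placing the slot first is an illustrative assumption."""
    return [MEMORY_TOKEN] + obs_tokens

# Toy 2-D codebook and features, made up for illustration.
codebook = [(0.0, 0.0), (1.0, 1.0), (5.0, 5.0)]
tokens = quantize([(0.1, -0.2), (4.0, 4.5), (0.9, 1.2)], codebook)
print(tokens)                       # [0, 2, 1]
print(build_policy_input(tokens))   # [-1, 0, 2, 1]
```

In a trained model, the memory slot would hold a learned embedding that carries information from earlier time steps rather than a fixed integer id. The integer here only marks where that slot sits in the sequence.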
