Mastering Memory Tasks with World Models
Mohammad Reza Samsami, Artem Zholus, Janarthanan Rajendran, Sarath Chandar
TL;DR
This work tackles long-term dependencies in model-based RL by introducing Recall to Imagine (R2I), which embeds a Structured State Space Model (S3M) into a DreamerV3-based world model to provide long-term memory and improve long-horizon credit assignment. The key design change is to replace DreamerV3's recurrent posterior with a non-recurrent one. Posteriors for a whole sequence can then be computed independently and fed to the SSM in parallel during world-model training, while the SSM still carries temporal information across steps. Empirically, R2I achieves state-of-the-art results in memory-intensive domains (BSuite, POPGym, Memory Maze), even surpassing human performance in Memory Maze. It also remains competitive on Atari and DMC and converges up to 9x faster in wall-clock time than DreamerV3. The work shows that SSM-based world models can generalize across memory and non-memory domains, with substantial gains in sample efficiency and memory capability for RL systems.
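Why a non-recurrent posterior matters for parallelism can be illustrated with the simplest linear state space model. When the SSM's input at every step is known up front, the recurrence h_t = A·h_{t-1} + B·x_t is a composition of affine maps, and an associative (parallel) scan reproduces the step-by-step result. If instead the input at step t depended on h_{t-1}, as with a recurrent posterior, the inputs could not be formed before the scan. The following is a minimal pure-Python sketch, assuming a single-channel SSM with hypothetical fixed parameters `A`, `B`, `C`. It is not R2I's code; the structured SSM layers R2I actually uses are more elaborate.

```python
import math
from typing import List, Tuple

# Hypothetical scalar SSM parameters (illustrative only, not taken from R2I).
A = 0.9  # state decay; values close to 1 retain information over longer horizons
B = 0.5  # input projection
C = 2.0  # output projection


def ssm_recurrent(xs: List[float]) -> List[float]:
    """Step-by-step mode: h_t = A*h_{t-1} + B*x_t, y_t = C*h_t, with h_0 = 0."""
    h = 0.0
    ys = []
    for x in xs:
        h = A * h + B * x
        ys.append(C * h)
    return ys


def combine(left: Tuple[float, float], right: Tuple[float, float]) -> Tuple[float, float]:
    """Associative operator for affine maps h -> a*h + b.

    Applying `left` (earlier step) and then `right` (later step) gives
    h -> a_r*(a_l*h + b_l) + b_r = (a_r*a_l)*h + (a_r*b_l + b_r).
    """
    a_l, b_l = left
    a_r, b_r = right
    return (a_r * a_l, a_r * b_l + b_r)


def inclusive_scan(elems: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
    """Hillis-Steele inclusive scan: O(log T) rounds over the sequence.

    Every update within a round reads only the previous round's values,
    so on parallel hardware all updates in a round could run at once.
    """
    out = list(elems)
    step = 1
    while step < len(out):
        out = [combine(out[i - step], out[i]) if i >= step else out[i]
               for i in range(len(out))]
        step *= 2
    return out


def ssm_parallel(xs: List[float]) -> List[float]:
    """Scan mode: valid only because every x_t is available before the scan."""
    elems = [(A, B * x) for x in xs]
    prefixes = inclusive_scan(elems)
    # With h_0 = 0, the prefix map (a, b) sends h_0 to h_t = a*0 + b = b.
    return [C * b for (_, b) in prefixes]


if __name__ == "__main__":
    xs = [1.0, 0.0, 0.0, 2.0, -1.0]
    rec, par = ssm_recurrent(xs), ssm_parallel(xs)
    print(all(math.isclose(r, p) for r, p in zip(rec, par)))
```

Both functions return the same outputs up to floating-point rounding. The scan needs only about log2(T) sequential rounds instead of T sequential steps, which is the kind of saving that motivates removing the recurrence from the posterior.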
Abstract
Current model-based reinforcement learning (MBRL) agents struggle with long-term dependencies. This limits their ability to solve tasks with extended time gaps between actions and outcomes, or tasks that require recalling distant observations to inform current actions. To improve temporal coherence, we integrate a new family of state space models (SSMs) into the world models of MBRL agents, yielding a new method, Recall to Imagine (R2I). This integration aims to enhance both long-term memory and long-horizon credit assignment. Through a diverse set of illustrative tasks, we systematically demonstrate that R2I not only establishes a new state of the art on challenging memory and credit assignment RL tasks, such as BSuite and POPGym, but also achieves superhuman performance in the complex memory domain of Memory Maze. At the same time, it maintains comparable performance on classic RL tasks, such as Atari and DMC, suggesting that the method generalizes. We also show that R2I is faster than the state-of-the-art MBRL method, DreamerV3, resulting in faster wall-time convergence.
