
Drama: Mamba-Enabled Model-Based Reinforcement Learning Is Sample and Parameter Efficient

Wenlong Wang, Ivana Dusparic, Yucheng Shi, Ke Zhang, Vinny Cahill

TL;DR

Drama introduces a Mamba-based state-space world model for model-based reinforcement learning that achieves linear ($O(n)$) memory and computational complexity, allowing it to handle long sequences efficiently. The architecture couples a discrete latent VAE with a Mamba-2 sequence model and a lightweight reward/termination head. It also uses Dynamic Frequency-Based Sampling (DFS) to mitigate the suboptimality caused by an inaccurate early-stage world model during imagination-based policy learning. On Atari100k, Drama achieves competitive performance with a 7M-parameter world model. Ablations confirm the effectiveness of DFS, show that Mamba-2 outperforms Mamba in several games, and indicate favorable handling of long sequences. The work offers a practical, hardware-friendly alternative to transformer-based world models, with strong parameter efficiency and potential for improved exploration and long-horizon planning in real-world RL settings.
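The summary does not specify how DFS weights transitions. As a rough illustration only, the sketch below assumes one plausible frequency-based rule. Each stored transition keeps a count of how often it has been drawn for world-model training, and sampling probabilities are a softmax over the negated counts, so rarely used transitions are preferred. The class name `FrequencySampler`, the `temperature` parameter, and the softmax weighting are hypothetical and are not Drama's published rule.

```python
import math
import random


class FrequencySampler:
    """Hypothetical frequency-based replay sampler (illustration, not Drama's exact rule).

    Each stored transition keeps a count of how often it has been drawn for
    world-model training. Sampling weights are softmax(-count / temperature),
    so transitions that have been used less often are more likely to be drawn.
    """

    def __init__(self, temperature: float = 1.0, seed: int = 0) -> None:
        self.counts: list[int] = []     # per-transition draw counts
        self.temperature = temperature  # hypothetical sharpness knob
        self.rng = random.Random(seed)

    def add(self) -> int:
        """Register a newly collected transition and return its index."""
        self.counts.append(0)
        return len(self.counts) - 1

    def probabilities(self) -> list[float]:
        """Softmax over negated counts (max-subtracted for numerical stability)."""
        logits = [-c / self.temperature for c in self.counts]
        top = max(logits)
        weights = [math.exp(z - top) for z in logits]
        total = sum(weights)
        return [w / total for w in weights]

    def sample(self, k: int) -> list[int]:
        """Draw k indices with replacement and increment their counts."""
        idx = self.rng.choices(range(len(self.counts)), weights=self.probabilities(), k=k)
        for i in idx:
            self.counts[i] += 1
        return idx


# Usage: a freshly added transition (count 0) receives the highest weight.
sampler = FrequencySampler()
for _ in range(3):
    sampler.add()
sampler.sample(k=8)
new = sampler.add()
probs = sampler.probabilities()
print(probs[new] == max(probs))  # True
```

Counting draws keeps the bookkeeping to one integer per transition. The temperature controls how strongly fresh data is favored over data the model has already seen many times.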

Abstract

Model-based reinforcement learning (RL) offers a solution to the data inefficiency that plagues most model-free RL algorithms. However, learning a robust world model often requires complex and deep architectures, which are computationally expensive and challenging to train. Within the world model, sequence models play a critical role in accurate predictions, and various architectures have been explored, each with its own challenges. Currently, recurrent neural network (RNN)-based world models struggle with vanishing gradients and capturing long-term dependencies. Transformers, on the other hand, suffer from the quadratic memory and computational complexity of self-attention mechanisms, scaling as $O(n^2)$, where $n$ is the sequence length. To address these challenges, we propose a state space model (SSM)-based world model, Drama, specifically leveraging Mamba, that achieves $O(n)$ memory and computational complexity while effectively capturing long-term dependencies and enabling efficient training with longer sequences. We also introduce a novel sampling method to mitigate the suboptimality caused by an incorrect world model in the early training stages. Combining these techniques, Drama achieves a normalised score on the Atari100k benchmark that is competitive with other state-of-the-art (SOTA) model-based RL algorithms, using only a 7 million-parameter world model. Drama is accessible and trainable on off-the-shelf hardware, such as a standard laptop. Our code is available at https://github.com/realwenlongwang/Drama.git.
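The following toy, single-channel sketch makes the $O(n)$ versus $O(n^2)$ contrast concrete. It assumes a scalar state and precomputed input-dependent coefficients $a_t, b_t, c_t$. It is not the actual Mamba or Mamba-2 kernel, and the function names are hypothetical. The recurrence $h_t = a_t h_{t-1} + b_t x_t$, $y_t = c_t h_t$ does constant work per step. Materializing the equivalent lower-triangular (semi-separable) matrix, by contrast, needs $O(n^2)$ entries, the same scaling as an attention matrix.

```python
def ssm_scan(x, a, b, c):
    """Linear-time scan of a scalar selective SSM.

    h_t = a_t * h_{t-1} + b_t * x_t,  y_t = c_t * h_t,  with h_{-1} = 0.
    Work is O(n) and only one state value is kept, so memory is O(1) beyond the outputs.
    """
    h = 0.0
    ys = []
    for xt, at, bt, ct in zip(x, a, b, c):
        h = at * h + bt * xt
        ys.append(ct * h)
    return ys


def ssm_matrix(a, b, c):
    """Materialize the equivalent n x n lower-triangular (semi-separable) matrix.

    M[t][s] = c_t * (a_{s+1} * ... * a_t) * b_s for s <= t, and 0 otherwise.
    Building it takes O(n^2) time and memory, like an attention matrix.
    """
    n = len(a)
    m = [[0.0] * n for _ in range(n)]
    for t in range(n):
        decay = 1.0  # product a_{s+1} * ... * a_t; empty product when s == t
        for s in range(t, -1, -1):
            m[t][s] = c[t] * decay * b[s]
            decay *= a[s]
    return m


def matvec(m, x):
    """Dense matrix-vector product y = M x."""
    return [sum(mij * xj for mij, xj in zip(row, x)) for row in m]


# Usage: both routes give the same outputs (up to floating-point rounding).
x = [1.0, 2.0, -1.0, 0.5]
a = [0.9, 0.5, 0.8, 0.7]  # input-dependent decay, in (0, 1)
b = [1.0, 0.3, 0.6, 1.2]
c = [0.4, 1.1, 0.9, 0.2]
y_scan = ssm_scan(x, a, b, c)
y_mat = matvec(ssm_matrix(a, b, c), x)
print(all(abs(p - q) < 1e-12 for p, q in zip(y_scan, y_mat)))  # True
```

The scan never builds the $n \times n$ matrix, which is why sequence length enters only linearly. Figure 1's caption notes that Mamba-2 instead decomposes this matrix into $q \times q$ sub-matrices for more efficient computation.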

Paper Structure

This paper contains 31 sections, 8 equations, 10 figures, 8 tables, 1 algorithm.

Figures (10)

  • Figure 1: Drama world model architecture. At each sequence index $t$, the raw game frames are encoded into ${\bm{z}}_t$ and combined with the action $a_t$ as input to the Mamba blocks. The input channel dimension is divided by the head dimension $p$ to generate the deterministic recurrent state $d_t$. This recurrent state $d_t$ is used to predict the next embedding $\hat{{\bm{z}}}_{t+1}$, reward $\hat{r}_t$, and termination flag $\hat{e}_t$, which represent the outcomes based on the current frame and action. The decoder reconstructs the original frame from the encoded embeddings ${\bm{z}}_t$ rather than from the predicted embeddings $\hat{{\bm{z}}}_t$. The Mamba-2 block employs a semi-separable matrix structure, which can be decomposed into $q \times q$ sub-matrices, enabling more efficient computation and processing.
  • Figure 2: Mamba vs. Mamba-2. Mamba-2 outperforms Mamba in three of the four games. Both Mamba and Mamba-2 use DFS in this experiment.
  • Figure 3: Illustrations of the grid world environment and its reconstruction into a sequential format. (a) Sequence of consecutive frames in the grid world environment, arranged from left to right. Each frame is a $5 \times 5$ grid in which the outer 16 cells are black walls and the central $3 \times 3$ grid is the reachable space. The red cell is the controllable agent, which moves according to a random action, and the yellow cell is a fixed goal. From left to right, the frames show the agent following the action sequence $east \rightarrow south \rightarrow east \rightarrow north$. Once the agent reaches the yellow cell, the agent and goal positions are reset randomly. (b) Reconstructing the grid world into a long sequence. Each grey-shaded box contains 25 flattened grid tokens and one action token.
  • Figure 4: Atari100k Learning Curve. This figure compares the performance of DramaXS (10 million parameters) and DreamerV3XS (12 million parameters) on the Atari100k benchmark. DramaXS outperforms DreamerV3XS in most games. Exceptions include PrivateEye and Qbert, where DreamerV3XS performs better.
  • Figure 5: Uniform Sampling vs. DFS Learning Curve. DFS outperforms uniform sampling in 11 games (e.g., Asterix, BankHeist, Krull), underperforms in 2 games (Breakout, KungFuMaster), and matches its performance in the remaining 13 games. The normalised mean score of DFS (105%) surpasses that of uniform sampling (80%), while the normalised medians are comparable (27% vs. 28%). DFS performs better in games that require exploiting the opponent's strategy (e.g., Pong, Boxing) but struggles in environments with early-stage dynamics (Breakout).
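Figure 3(b) describes turning grid-world episodes into one long sequence in which every step contributes 25 flattened grid tokens followed by one action token. The sketch below shows that flattening under several assumptions. The integer token ids for cells and actions are hypothetical, frames are flattened row-major, and the agent stays in place when a move would hit a wall. The random reset after reaching the goal is omitted.

```python
# Hypothetical token ids; the figure does not specify the vocabulary.
WALL, EMPTY, AGENT, GOAL = 0, 1, 2, 3
ACTION_TOKENS = {"north": 4, "south": 5, "east": 6, "west": 7}
MOVES = {"north": (-1, 0), "south": (1, 0), "east": (0, 1), "west": (0, -1)}


def render(agent, goal):
    """Return one 5x5 frame flattened row-major into 25 cell tokens."""
    tokens = []
    for row in range(5):
        for col in range(5):
            if row in (0, 4) or col in (0, 4):
                tokens.append(WALL)  # the 16 border cells
            elif (row, col) == agent:
                tokens.append(AGENT)
            elif (row, col) == goal:
                tokens.append(GOAL)
            else:
                tokens.append(EMPTY)
    return tokens


def step(agent, action):
    """Move one cell; stay put if the move would leave the central 3x3 area."""
    dr, dc = MOVES[action]
    row, col = agent[0] + dr, agent[1] + dc
    return (row, col) if 1 <= row <= 3 and 1 <= col <= 3 else agent


def flatten_episode(agent, goal, actions):
    """Emit [25 grid tokens, 1 action token] per step as one flat sequence."""
    seq = []
    for action in actions:
        seq.extend(render(agent, goal))
        seq.append(ACTION_TOKENS[action])
        agent = step(agent, action)
    return seq


# Usage: the action sequence from Figure 3(a), with a hypothetical start (1, 1) and goal (3, 3).
seq = flatten_episode((1, 1), (3, 3), ["east", "south", "east", "north"])
print(len(seq))  # 104 = 4 steps x 26 tokens
```

Sequence length grows by 26 tokens per environment step. A model with linear complexity in sequence length is therefore attractive for this kind of long-context probe.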