Accelerating Model-Based Reinforcement Learning with State-Space World Models
Maria Krinner, Elie Aljalbout, Angel Romero, Davide Scaramuzza
TL;DR
Model-based RL often yields better sample efficiency but suffers from slow training due to sequential world-model updates. The authors introduce S5WM, a state-space world model that replaces the recurrent state-space model (RSSM) with the parallelizable S5 architecture and leverages privileged state information during training, achieving substantial speedups without sacrificing performance. Across state-based and vision-based quadrotor tasks, S5WM matches or exceeds DreamerV3 in task rewards and sample efficiency while delivering up to $4\times$ faster overall training and up to $10\times$ faster world-model training. The work demonstrates strong sim-to-real transfer on agile drone tasks and provides a roadmap for faster, more practical MBRL in real-world robotics.
Abstract
Reinforcement learning (RL) is a powerful approach for robot learning. However, model-free RL (MFRL) requires a large number of environment interactions to learn successful control policies. This is due to noisy RL training updates and the complexity of robotic systems, which typically involve highly non-linear dynamics and noisy sensor signals. In contrast, model-based RL (MBRL) not only trains a policy but simultaneously learns a world model that captures the environment's dynamics and rewards. The world model can be used for planning, for data collection, or to provide first-order policy gradients for training. Leveraging a world model significantly improves sample efficiency compared to MFRL. However, training a world model alongside the policy increases the computational complexity, leading to longer training times that are often intractable for complex real-world scenarios. In this work, we propose a new method for accelerating model-based RL using state-space world models. Our approach leverages state-space models (SSMs) to parallelize the training of the dynamics model, which is typically the main computational bottleneck. Additionally, we propose an architecture that provides privileged information to the world model during training, which is particularly relevant for partially observable environments. We evaluate our method in several real-world agile quadrotor flight tasks, involving complex dynamics, for both fully and partially observable environments. We demonstrate a significant speedup, reducing the world model training time by up to 10 times, and the overall MBRL training time by up to 4 times. This benefit comes without compromising performance, as our method achieves similar sample efficiency and task rewards to state-of-the-art MBRL methods.
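To illustrate why a state-space dynamics model can be trained in parallel over time, the following is a minimal sketch and not the authors' implementation. It assumes a scalar linear recurrence x_t = a_t * x_{t-1} + b_t as a stand-in for one channel of a diagonal state-space model; the function names (`combine`, `sequential_scan`, `parallel_style_scan`) are hypothetical. The key observation is that each step is an affine map, and composing affine maps is associative, so all states can be obtained with a prefix scan in roughly log2(T) passes instead of T dependent steps.

```python
from typing import List, Sequence, Tuple

# An element (a, b) represents the affine map x -> a * x + b.
Elem = Tuple[float, float]


def combine(first: Elem, second: Elem) -> Elem:
    """Compose two affine maps: apply `first`, then `second`."""
    a1, b1 = first
    a2, b2 = second
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(a: Sequence[float], b: Sequence[float], x0: float = 0.0) -> List[float]:
    """Reference recurrence: each state depends on the previous one."""
    x = x0
    states = []
    for a_t, b_t in zip(a, b):
        x = a_t * x + b_t
        states.append(x)
    return states


def parallel_style_scan(a: Sequence[float], b: Sequence[float], x0: float = 0.0) -> List[float]:
    """Inclusive prefix scan over affine maps (Hillis-Steele style).

    Within one pass, every position is computed only from the previous
    pass, so on parallel hardware all positions could be updated at once.
    Plain Python still runs them one by one; only the dependency
    structure is illustrated here.
    """
    elems: List[Elem] = list(zip(a, b))
    n = len(elems)
    step = 1
    while step < n:
        elems = [
            elems[i] if i < step else combine(elems[i - step], elems[i])
            for i in range(n)
        ]
        step *= 2
    # elems[t] now maps the initial state x0 directly to x_t.
    return [a_t * x0 + b_t for a_t, b_t in elems]


if __name__ == "__main__":
    a = [0.9, 0.5, 0.8, 1.1, 0.7]
    b = [1.0, -2.0, 0.5, 0.0, 3.0]
    print(sequential_scan(a, b, x0=2.0))
    print(parallel_style_scan(a, b, x0=2.0))
```

Both functions return the same states up to floating-point rounding. A recurrent model with a non-linear state update does not admit this reformulation in general, which is consistent with the abstract's point that the dynamics model is the main sequential bottleneck.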
