Masked World Models for Visual Control
Younggyo Seo, Danijar Hafner, Hao Liu, Fangchen Liu, Stephen James, Kimin Lee, Pieter Abbeel
TL;DR
MWM tackles sample-efficient visual model-based reinforcement learning by decoupling visual representation learning from dynamics learning. It trains an autoencoder that reconstructs pixels from masked convolutional features, adds an auxiliary reward-prediction objective, and learns a latent dynamics model on the resulting representations, with both components updated online. Empirically, it achieves state-of-the-art results on challenging visual robotic tasks from Meta-world and RLBench, outperforming DreamerV2, and shows that masking convolutional features can outperform patch-based MAE masking. The work also offers qualitative insight into how reward-guided representations and task-focused latent dynamics improve prediction of relevant objects and actions, and it points to possible multi-modal and temporally richer extensions.
Abstract
Visual model-based reinforcement learning (RL) has the potential to enable sample-efficient robot learning from visual observations. Yet current approaches typically train a single model end-to-end for learning both visual representations and dynamics, making it difficult to accurately model the interaction between robots and small objects. In this work, we introduce a visual model-based RL framework that decouples visual representation learning and dynamics learning. Specifically, we train an autoencoder with convolutional layers and vision transformers (ViT) to reconstruct pixels given masked convolutional features, and learn a latent dynamics model that operates on the representations from the autoencoder. Moreover, to encode task-relevant information, we introduce an auxiliary reward prediction objective for the autoencoder. We continually update both the autoencoder and the dynamics model using online samples collected from environment interaction. We demonstrate that our decoupling approach achieves state-of-the-art performance on a variety of visual robotic tasks from Meta-world and RLBench, e.g., we achieve an 81.7% success rate on 50 visual robotic manipulation tasks from Meta-world, while the baseline achieves 67.9%. Code is available on the project website: https://sites.google.com/view/mwm-rl.
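To make the masking step in the abstract concrete, the following is a minimal NumPy sketch and not the authors' implementation. It assumes a convolutional stem has already produced a feature map of shape (H, W, C), drops a random fraction of the spatial feature tokens, and combines a pixel reconstruction term with a reward prediction term. The function names, the 0.75 mask ratio in the usage lines, the zero-valued mask token, the squared-error losses, and the `reward_weight` parameter are all illustrative assumptions. The ViT encoder and decoder that would sit between masking and reconstruction are omitted.

```python
import numpy as np


def mask_conv_features(features, mask_ratio, rng):
    """Randomly drop a fraction of spatial feature tokens (MAE-style masking).

    features: array of shape (H, W, C), e.g. the output of a conv stem.
    Returns (visible_tokens, keep_idx, mask), where mask[i] == 1 marks a
    dropped token and mask[i] == 0 marks a visible one.
    """
    h, w, c = features.shape
    tokens = features.reshape(h * w, c)
    num_keep = int(round(h * w * (1.0 - mask_ratio)))
    perm = rng.permutation(h * w)
    keep_idx = np.sort(perm[:num_keep])
    mask = np.ones(h * w, dtype=np.int64)
    mask[keep_idx] = 0
    return tokens[keep_idx], keep_idx, mask


def unmask(visible_tokens, keep_idx, num_tokens, mask_token):
    """Rebuild the full token sequence, filling dropped slots with mask_token."""
    full = np.tile(mask_token, (num_tokens, 1))
    full[keep_idx] = visible_tokens
    return full


def autoencoder_loss(pixels, pixels_hat, reward, reward_hat, reward_weight=1.0):
    """Pixel reconstruction error plus a weighted reward prediction error.

    Both terms are squared errors here; the loss form and weight are assumptions.
    """
    recon = np.mean((pixels - pixels_hat) ** 2)
    reward_err = (reward - reward_hat) ** 2
    return float(recon + reward_weight * reward_err)


# Usage: mask 75% of an 8x8 grid of 16-dimensional conv features.
rng = np.random.default_rng(0)
feats = rng.normal(size=(8, 8, 16))
visible, keep_idx, mask = mask_conv_features(feats, 0.75, rng)
full = unmask(visible, keep_idx, 64, np.zeros(16))
```

In this sketch only the visible tokens would be passed to the encoder, and the decoder would receive the full-length sequence with mask tokens in the dropped positions, following the usual masked autoencoder recipe.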
