Dreaming of Many Worlds: Learning Contextual World Models Aids Zero-Shot Generalization

Sai Prasanna, Karim Farid, Raghu Rajan, André Biedenkapp

TL;DR

This work tackles zero-shot generalization in model-based RL by leveraging observable contextual information. It introduces a Contextual Recurrent State-Space Model (cRSSM) that extends the Dreamer framework to condition latent dynamics and observations on context, enabling the agent to imagine trajectories under counterfactual contexts. Through experiments on CARL tasks, the authors demonstrate that explicit context conditioning (particularly via cRSSM) improves zero-shot generalization, with latent disentanglement and counterfactual dreaming offering qualitative insights into the learned representations. The study highlights the value of principled context integration over naive methods and points to future work on relaxing the assumption of observable context and expanding benchmarking for ZSG in contextual RL.

Abstract

Zero-shot generalization (ZSG) to unseen dynamics is a major challenge for creating generally capable embodied agents. To address the broader challenge, we start with the simpler setting of contextual reinforcement learning (cRL), assuming observability of the context values that parameterize the variation in the system's dynamics, such as the mass or dimensions of a robot, without making further simplifying assumptions about the observability of the Markovian state. Toward the goal of ZSG to unseen variation in context, we propose the contextual recurrent state-space model (cRSSM), which introduces changes to the world model of Dreamer (v3) (Hafner et al., 2023). This allows the world model to incorporate context for inferring latent Markovian states from the observations and modeling the latent dynamics. Our approach is evaluated on two tasks from the CARL benchmark suite, which is tailored to study contextual RL. Our experiments show that such systematic incorporation of the context improves the ZSG of the policies trained on the "dreams" of the world model. We further find qualitatively that our approach allows Dreamer to disentangle the latent state from context, allowing it to extrapolate its dreams to the many worlds of unseen contexts. The code for all our experiments is available at https://github.com/sai-prasanna/dreaming_of_many_worlds.
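
To make the idea in the abstract concrete, here is a minimal sketch of a context-conditioned latent rollout. It is not the authors' implementation: the function names (`step`, `dream`), the layer sizes, and the random untrained weights are all hypothetical, and the stochastic latent is replaced by a deterministic mean for brevity. It only illustrates how feeding the context `c` into the latent dynamics lets the same model "dream" from one start state under different contexts.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: recurrent state, latent state, action, context.
H, Z, A, C = 8, 4, 1, 1


def linear(n_in, n_out):
    """Random weight matrix standing in for a trained layer."""
    return rng.normal(scale=0.3, size=(n_in, n_out))


W_seq = linear(H + Z + A + C, H)  # sequence model receives the context
W_prior = linear(H, Z)            # latent predicted from the recurrent state


def step(h, z, a, c):
    """One imagined step: update the recurrent state, then predict the latent."""
    h_next = np.tanh(np.concatenate([h, z, a, c]) @ W_seq)
    z_next = np.tanh(h_next @ W_prior)  # mean only; sampling is omitted
    return h_next, z_next


def dream(h0, z0, actions, c):
    """Roll the latent dynamics forward under a fixed context c."""
    h, z = h0, z0
    latents = []
    for a in actions:
        h, z = step(h, z, a, c)
        latents.append(z)
    return np.stack(latents)


h0, z0 = np.zeros(H), np.zeros(Z)
actions = np.ones((5, A))
dream_short_pole = dream(h0, z0, actions, np.array([0.5]))
dream_long_pole = dream(h0, z0, actions, np.array([1.0]))  # counterfactual context
```

Both rollouts share the start state and the action sequence, so any difference between them comes from the context input alone. In a trained model, that difference is what a counterfactual dream would reflect.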

Paper Structure

This paper contains 44 sections, 4 equations, 43 figures, and 5 tables.

Figures (43)

  • Figure 1: Latent dynamics models. The models shown observe the first two time steps and predict the third. Circles represent stochastic variables, and squares represent deterministic variables. Solid lines denote the generative process, and dashed lines denote the inference model. The context node and edges are highlighted in red. The two panels show the generative model for a cMDP and our cRSSM.
  • Figure 2: Training contexts and evaluation regions for single and dual context variation.
  • Figure 3: Generalization capabilities of Dreamer with pixel observations when varying the pole length in CartPole. The y-axis indicates the reward obtained, and the x-axis the pole length. The blue bars show vanilla Dreamer's performance when training only on the default length ($0.50$) and extrapolating to other settings. The shaded area gives the training range for the methods using context. Expert and random policies give upper and lower bounds on performance in each context.
  • Figure 4: Aggregated comparison (across contexts and tasks) of cRSSM, concat-context, hidden-context, and default-context in the evaluation settings Interpolation, Extrapolation, and Inter+Extrapolation, using IQM over expert-normalized scores for both input modalities. Intervals shown are stratified bootstrap 95% confidence intervals over seeds and aggregated contexts.
  • Figure 5: Extrapolated Contexts
  • ...and 38 more figures
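
The aggregate metric named in the Figure 4 caption can be sketched as follows. This is an illustration under stated assumptions, not the paper's evaluation code: the normalization formula (random policy maps to 0, expert to 1) is assumed from the Figure 3 caption's description of expert and random policies as bounds, the IQM is taken as the mean of the middle 50% of scores, and the function names and the `(n_seeds, n_contexts)` array layout are hypothetical.

```python
import numpy as np


def expert_normalize(score, random_score, expert_score):
    """Assumed normalization: random policy -> 0, expert policy -> 1."""
    return (score - random_score) / (expert_score - random_score)


def iqm(scores):
    """Interquartile mean: drop the lowest and highest 25%, average the rest."""
    flat = np.sort(np.asarray(scores, dtype=float).ravel())
    k = int(np.floor(0.25 * flat.size))
    return flat[k:flat.size - k].mean()


def stratified_bootstrap_ci(scores, n_boot=2000, alpha=0.05, seed=0):
    """Percentile CI for the IQM; seeds are resampled separately per context.

    `scores` has shape (n_seeds, n_contexts); each column is one stratum.
    """
    rng = np.random.default_rng(seed)
    n_seeds, n_ctx = scores.shape
    stats = np.empty(n_boot)
    for b in range(n_boot):
        # Independent seed indices for every context column.
        idx = rng.integers(0, n_seeds, size=(n_seeds, n_ctx))
        resampled = np.take_along_axis(scores, idx, axis=0)
        stats[b] = iqm(resampled)
    return np.quantile(stats, [alpha / 2, 1 - alpha / 2])


# Made-up numbers purely to exercise the functions.
raw = np.array([[120.0, 300.0], [150.0, 280.0], [90.0, 310.0], [200.0, 250.0]])
normalized = expert_normalize(raw, random_score=20.0, expert_score=500.0)
point_estimate = iqm(normalized)
ci_low, ci_high = stratified_bootstrap_ci(normalized)
```

The IQM discards the top and bottom quarters of the scores, which makes it less sensitive to outlier seeds than the mean while using more of the data than the median.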