Contextual Latent World Models for Offline Meta Reinforcement Learning

Mohammadreza Nakheai, Aidan Scannell, Kevin Luck, Joni Pajarinen

TL;DR

Contextual latent world models condition latent world models on inferred task representations and are trained jointly with the context encoder, yielding task representations that capture task-dependent dynamics rather than merely discriminating between tasks.

Abstract

Offline meta-reinforcement learning seeks to learn policies that generalize across related tasks from fixed datasets. Context-based methods infer a task representation from transition histories, but learning effective task representations without supervision remains a challenge. In parallel, latent world models have demonstrated strong self-supervised representation learning through temporal consistency. We introduce contextual latent world models, which condition latent world models on inferred task representations and train them jointly with the context encoder. This enforces task-conditioned temporal consistency, yielding task representations that capture task-dependent dynamics rather than merely discriminating between tasks. Our method learns more expressive task representations and significantly improves generalization to unseen tasks across MuJoCo, Contextual-DeepMind Control, and Meta-World benchmarks.
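The core mechanism can be sketched as a single loss. This is a minimal, dependency-free illustration resting on several assumptions that do not come from the paper. The context encoder $E_\theta$ is assumed to be a linear embedding of each flattened transition followed by mean pooling, and the dynamics heads are assumed linear. FSQ, which Figure 1 names as the discretizer, is assumed to use an odd number of levels per dimension. The classification loss is assumed to be a per-dimension cross-entropy over FSQ levels. All function and variable names (`encode_context`, `fsq_code`, `task_conditioned_tc_loss`, `W_ctx`, `heads`) are hypothetical. No gradients are computed here. In a real implementation, the loss would be backpropagated through both the dynamics heads and $E_\theta$, which is what makes the training joint.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def matvec(W: Matrix, x: Sequence[float]) -> List[float]:
    """Dense matrix-vector product for small illustrative weight matrices."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def encode_context(transitions: Sequence[Sequence[float]], W_ctx: Matrix) -> List[float]:
    """Hypothetical context encoder: embed each flattened (s, a, r, s') transition
    linearly, then mean-pool, so z does not depend on transition order."""
    embeds = [matvec(W_ctx, t) for t in transitions]
    return [sum(e[k] for e in embeds) / len(embeds) for k in range(len(W_ctx))]


def fsq_code(x: Sequence[float], levels: int) -> List[int]:
    """FSQ with an odd number of levels: tanh bounds each entry to (-h, h) with
    h = (levels - 1) / 2, and rounding yields an integer in {-h, ..., h}.
    (Training would pass gradients through round() with a straight-through
    estimator; this sketch computes no gradients.)"""
    h = (levels - 1) / 2
    return [round(h * math.tanh(v)) for v in x]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Numerically stable softmax cross-entropy for one integer class label."""
    m = max(logits)
    return m + math.log(sum(math.exp(v - m) for v in logits)) - logits[target]


def task_conditioned_tc_loss(
    code_t: Sequence[int],
    action: Sequence[float],
    z: Sequence[float],
    next_latent: Sequence[float],
    heads: List[Matrix],
    levels: int,
) -> float:
    """Classification-style temporal consistency: from [c_t, a_t, z], predict the
    FSQ level of every dimension of the next latent (one linear head per dimension)."""
    offset = (levels - 1) // 2
    features = [float(c) for c in code_t] + list(action) + list(z)
    # Shift codes from {-offset, ..., offset} to class labels {0, ..., levels - 1}.
    targets = [c + offset for c in fsq_code(next_latent, levels)]
    losses = [cross_entropy(matvec(W, features), y) for W, y in zip(heads, targets)]
    return sum(losses) / len(losses)


# Usage with toy numbers: a 2-dim z inferred from two 3-dim transitions.
W_ctx = [[0.1, 0.2, 0.0], [0.0, -0.1, 0.3]]
z = encode_context([[1.0, 0.0, 2.0], [0.5, 1.0, -1.0]], W_ctx)
# All-zero heads (2 latent dims, 5 levels, 5 input features) give uniform logits,
# so this loss equals math.log(5).
heads = [[[0.0] * 5 for _ in range(5)] for _ in range(2)]
loss = task_conditioned_tc_loss([0, 1], [0.5], z, [0.3, -2.0], heads, levels=5)
```

Because $\mathbf{z}$ enters the prediction features, the predictor can only resolve task-dependent differences in the next latent state if $E_\theta$ encodes them. This is the intuition behind the claim that task-conditioned temporal consistency yields representations that capture dynamics rather than only separating tasks.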

Paper Structure (62 sections, 6 theorems, 32 equations, 10 figures, 15 tables, 1 algorithm)

Key Result

Theorem 3.1

For any task $i$ and any latent policy $\pi$,

Figures (10)

  • Figure 1: Method overview. Left: A context encoder $E_\theta$ maps transitions from each task to a task representation $\mathbf{z}$, which serves as an implicit task identifier. Middle: An observation encoder maps observations $\mathbf{s}_t$ to discrete latent vectors $\mathbf{c}_t$ using finite scalar quantization (FSQ). A task-conditioned latent dynamics model $D_\phi$ predicts future discrete latent states given the current discrete latent state, action, and task representation. The world model is trained using a classification loss based on temporal consistency. Right: An offline policy is trained using the discrete latent states $\mathbf{c}_t$ and task representation $\mathbf{z}$.
  • Figure 2: Representation metrics (Feature Rank, Rank, and Dormant ratio) of the context encoder. SPC maintains a low dormant-neuron ratio and a high matrix rank while learning more diverse features (higher feature rank) than pure reconstruction (UNICORN-SUP). The shaded area represents the 95% confidence interval across 6 random seeds.
  • Figure 3: Disentanglement metrics (DCI, InfoMEC) for the Cheetah-length-speed (Ls) environment. Latent world models disentangle the factors of variation more effectively, while contrastive learning enhances task distinguishability, as reflected in informativeness and explicitness. TC denotes training the context encoder solely with the latent temporal consistency (world model) objective; FOCAL and InfoNCE denote two contrastive objectives, and UNICORN-SUP denotes training with reconstruction. Averaged over 6 random seeds.
  • Figure 4: Different world model formulations: IQM and optimality gap of the normalized return for different world modeling methods, evaluated on 9 environments with 6 random seeds per environment. The main benefit of discretizing the latent space comes from the classification (cross-entropy) loss; bounding or discretizing the latent space alone does not improve performance.
  • Figure 5: Few-shot in-distribution performance on MuJoCo and Contextual DMC benchmarks: IQM and optimality gap of the normalized return (9 environments, each with 6 random seeds).
  • ...and 5 more figures
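Figures 4 and 5 summarize results with the interquartile mean (IQM) and optimality gap of normalized returns, pooled over 9 environments with 6 seeds each. The sketch below follows the definitions popularized by the rliable library. It assumes all environment-seed scores are pooled into one list and that the target score for the optimality gap is 1.0. Whether the paper's evaluation uses exactly these conventions is an assumption, and the stratified bootstrap confidence intervals that usually accompany these metrics are omitted.

```python
from typing import Sequence


def iqm(scores: Sequence[float]) -> float:
    """Interquartile mean over all pooled runs (e.g. every environment-seed pair).

    Sorts the scores, drops int(0.25 * n) from each end, and averages the rest;
    this is the same trimming rule as scipy.stats.trim_mean(x, 0.25)."""
    xs = sorted(scores)
    k = int(0.25 * len(xs))
    middle = xs[k:len(xs) - k]
    return sum(middle) / len(middle)


def optimality_gap(scores: Sequence[float], gamma: float = 1.0) -> float:
    """Mean shortfall below the target score gamma; runs at or above gamma count as 0."""
    return sum(max(gamma - x, 0.0) for x in scores) / len(scores)


# Hypothetical normalized returns: the single outlier (10.0) barely affects the IQM.
runs = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 10.0]
print(iqm(runs), optimality_gap(runs))
```

IQM is robust to a few outlier runs, unlike the mean, while staying less noisy than the median. The optimality gap ignores how far a run exceeds the target, so it rewards methods that avoid failing on any task.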

Theorems & Definitions (11)

  • Theorem 3.1: Value error from abstraction, model learning, and task inference
  • Lemma 1: Simulation lemma (TV form)
  • proof
  • Lemma 2: Original MDP $\rightarrow$ latent Markov MDP
  • proof
  • Lemma 3: Latent Markov MDP $\rightarrow$ learned world model (true task)
  • proof
  • Theorem 3.1: Original task MDP vs learned latent world model (true task)
  • proof
  • Lemma 4: Extra value error from task inference in the learned latent MDP
  • ...and 1 more
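Lemma 1 is listed here only by name. For reference, a standard textbook form of the simulation lemma in total-variation form reads as follows. Fix a policy $\pi$ and assume rewards in $[0, R_{\max}]$, $\max_{s,a} |r(s,a) - \hat r(s,a)| \le \epsilon_r$, and $\max_{s,a} \mathrm{TV}(P(\cdot \mid s,a), \hat P(\cdot \mid s,a)) \le \epsilon_P$, with TV defined as half the $L_1$ distance. Then $\lVert V^\pi_M - V^\pi_{\hat M} \rVert_\infty \le \frac{\epsilon_r}{1-\gamma} + \frac{\gamma\,\epsilon_P\,R_{\max}}{(1-\gamma)^2}$. The paper's Lemma 1 may use different constants or conventions. The sketch below checks this bound numerically on a hypothetical 3-state Markov chain, with the policy already folded into the transition matrix. The chain, its rewards, and the helper names (`evaluate`, `tv`, `simulation_bound`) are illustrative assumptions.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def evaluate(P: Matrix, r: Sequence[float], gamma: float, iters: int = 2000) -> List[float]:
    """Policy evaluation on a Markov chain (policy already folded into P):
    iterate V <- r + gamma * P V until it has effectively converged."""
    n = len(r)
    V = [0.0] * n
    for _ in range(iters):
        V = [r[s] + gamma * sum(P[s][t] * V[t] for t in range(n)) for s in range(n)]
    return V


def tv(p: Sequence[float], q: Sequence[float]) -> float:
    """Total variation distance between two categorical distributions (half L1)."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def simulation_bound(eps_p: float, eps_r: float, r_max: float, gamma: float) -> float:
    """Right-hand side of the TV-form simulation lemma stated above."""
    return eps_r / (1 - gamma) + gamma * eps_p * r_max / (1 - gamma) ** 2


# Hypothetical "true" chain and a slightly wrong learned model with the same rewards.
P = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
P_hat = [[0.6, 0.3, 0.1], [0.1, 0.7, 0.2], [0.3, 0.3, 0.4]]
r, gamma = [1.0, 0.0, 0.5], 0.9

V, V_hat = evaluate(P, r, gamma), evaluate(P_hat, r, gamma)
gap = max(abs(a - b) for a, b in zip(V, V_hat))
eps_p = max(tv(p, q) for p, q in zip(P, P_hat))
bound = simulation_bound(eps_p, eps_r=0.0, r_max=1.0, gamma=gamma)
print(f"value gap {gap:.4f} <= bound {bound:.4f}")
```

The bound is typically loose because it scales with $1/(1-\gamma)^2$. Judging from the lemma titles listed above, Theorem 3.1 chains bounds of this kind across state abstraction, model learning, and task inference; the exact terms are given in the paper.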