World Models via Policy-Guided Trajectory Diffusion

Marc Rigter, Jun Yamada, Ingmar Posner

TL;DR

PolyGRAD introduces a diffusion-based world model that generates entire on-policy trajectories non-autoregressively. It combines a denoising network with policy-guided action updates to produce trajectories that follow the current policy, and it analyses the connections between this procedure, score-based generative models, and classifier-guided diffusion. Experiments show low prediction error for short trajectories and competitive on-policy RL performance from imagined data in MuJoCo, while noting challenges with training stability and policy entropy. The work offers a new paradigm for accurate, non-autoregressive world modelling and points to latent diffusion and longer-horizon analysis as future directions.

Abstract

World models are a powerful tool for developing intelligent agents. By predicting the outcome of a sequence of actions, world models enable policies to be optimised via on-policy reinforcement learning (RL) using synthetic data, i.e. "in imagination". Existing world models are autoregressive in that they interleave predicting the next state with sampling the next action from the policy. Prediction error inevitably compounds as the trajectory length grows. In this work, we propose a novel world modelling approach that is not autoregressive and generates entire on-policy trajectories in a single pass through a diffusion model. Our approach, Policy-Guided Trajectory Diffusion (PolyGRAD), uses a denoising model together with the gradient of the policy's action distribution to diffuse a trajectory of initially random states and actions into an on-policy synthetic trajectory. We analyse the connections between PolyGRAD, score-based generative models, and classifier-guided diffusion models. Our results demonstrate that PolyGRAD outperforms state-of-the-art baselines in terms of trajectory prediction error for short trajectories, with the exception of autoregressive diffusion. For short trajectories, PolyGRAD obtains similar errors to autoregressive diffusion, but with lower computational requirements. For long trajectories, PolyGRAD obtains comparable performance to baselines. Our experiments demonstrate that PolyGRAD enables performant policies to be trained via on-policy RL in imagination for MuJoCo continuous control domains. Thus, PolyGRAD introduces a new paradigm for accurate on-policy world modelling without autoregressive sampling.
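The generation loop described in the abstract (and illustrated step by step in Figure 2) can be sketched in NumPy. This is a minimal sketch under stated assumptions, not the authors' implementation: `eps_model` is a hypothetical stand-in for the learnt denoising network $\epsilon$, the policy is assumed to be Gaussian with mean `mean_fn(s)` and a fixed standard deviation `sigma`, states follow a standard DDPM reverse step with a linear noise schedule, and the action update is written as an unadjusted Langevin step along $\nabla_a \log \pi(a \mid s)$. The paper's exact step sizes, noise terms and conditioning may differ.

```python
import numpy as np


def linear_beta_schedule(n_steps, beta_start=1e-4, beta_end=0.02):
    """Standard DDPM linear noise schedule (an illustrative choice, assumed here)."""
    betas = np.linspace(beta_start, beta_end, n_steps)
    alphas = 1.0 - betas
    return betas, alphas, np.cumprod(alphas)


def gaussian_policy_score(actions, states, mean_fn, sigma):
    """Score of an assumed Gaussian policy: grad_a log N(a; mean_fn(s), sigma^2 I)."""
    return -(actions - mean_fn(states)) / sigma**2


def policy_guided_diffusion(s0, eps_model, mean_fn, sigma, horizon, action_dim,
                            n_steps=50, step_size=0.01, rng=None):
    """Denoise a random (state, action) trajectory towards an on-policy one.

    eps_model(states, actions, t) is a hypothetical denoiser returning the
    predicted noise on the state sequence, conditioned on the actions.
    """
    rng = np.random.default_rng() if rng is None else rng
    betas, alphas, alpha_bars = linear_beta_schedule(n_steps)
    state_dim = s0.shape[0]

    # Start from random states and actions; the initial state is real data.
    states = rng.standard_normal((horizon, state_dim))
    actions = rng.standard_normal((horizon, action_dim))
    states[0] = s0

    for t in reversed(range(n_steps)):
        # 1) Action update: Langevin-style step along the policy score,
        #    conditioned on the current (still noisy) state sequence.
        score = gaussian_policy_score(actions, states, mean_fn, sigma)
        actions = (actions + step_size * score
                   + np.sqrt(2.0 * step_size) * rng.standard_normal(actions.shape))

        # 2) State update: one DDPM reverse step conditioned on the current actions.
        eps_hat = eps_model(states, actions, t)
        coef = betas[t] / np.sqrt(1.0 - alpha_bars[t])
        states = (states - coef * eps_hat) / np.sqrt(alphas[t])
        if t > 0:
            states = states + np.sqrt(betas[t]) * rng.standard_normal(states.shape)

        # Re-impose the known initial state after every denoising step.
        states[0] = s0

    return states, actions
```

Re-imposing `states[0] = s0` after every step is a simple inpainting-style way to condition on a real initial state. For a Gaussian policy with a fixed mean, the Langevin update only contracts towards the policy when `step_size < 2 * sigma**2`, so `step_size` has to shrink as the policy's entropy decreases.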

Paper Structure

This paper contains 45 sections, 16 equations, 18 figures, 4 tables, 3 algorithms.

Figures (18)

  • Figure 1: Top: Illustration of Policy-Guided Trajectory Diffusion (PolyGRAD). PolyGRAD starts with a trajectory of random states and actions and diffuses it into an on-policy trajectory using a learnt denoising model, $\epsilon$, and the policy, $\pi$. Bottom: A standard diffusion model trained on trajectories can generate synthetic trajectories, but these are not on-policy.
  • Figure 2: Step-by-step illustration of PolyGRAD trajectory generation (the PolyGRAD algorithm). Bottom left: Illustration of the policy's action distribution throughout the state space. Top left: The trajectory is initialised with random states and actions. The initial state (dark purple) is sampled from the dataset. Solid black arrows: The action sequence is updated according to the score of the policy, conditioned on the current state sequence. Hollow black arrows: The state sequence is updated conditioned on the current actions. Bottom right: The final trajectory is returned.
  • Figure 3: Plots of action distributions produced by PolyGRAD. The blue line shows the distribution of $a - \mu_\phi(s)$ for a batch of synthetic data. Each subplot corresponds to a policy with a different entropy level that is constant throughout the state space. The dashed black line indicates the action distribution output by the policy. Data is generated by running the RL-in-imagination algorithm in Walker2d with $h = 50$.
  • Figure 4: Computation times to produce a batch of 1000 trajectories on a V100 GPU in Walker2d.
  • Figure 5: Mean squared error (MSE) between predicted and ground-truth states for each world model in each environment, with all models trained on the same dataset. Shaded regions indicate the standard deviation over 5 seeds. For PolyGRAD, we use the transformer denoising network trained on trajectories of length $h = 10$, $50$, or $200$.
  • ...and 13 more figures
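Figure 3 checks whether PolyGRAD's synthetic actions are spread like the policy's own action distribution by plotting $a - \mu_\phi(s)$. A minimal sketch of that diagnostic follows, assuming a Gaussian policy with mean `mean_fn(s)` and a state-independent standard deviation `sigma` (matching the constant-entropy setting in the caption); the function name, bin count and returned summary statistics are hypothetical choices for illustration, not taken from the paper.

```python
import numpy as np


def action_residual_check(states, actions, mean_fn, sigma, bins=41):
    """Compare the empirical distribution of a - mu(s) with the policy's N(0, sigma^2).

    Returns the empirical residual std and the largest gap between a
    histogram density of the residuals and the Gaussian policy density.
    """
    residuals = (actions - mean_fn(states)).ravel()

    # Histogram density over +/- 4 sigma (values outside the range are ignored).
    hist, edges = np.histogram(residuals, bins=bins,
                               range=(-4.0 * sigma, 4.0 * sigma), density=True)
    centres = 0.5 * (edges[:-1] + edges[1:])

    # Density of the zero-mean Gaussian the policy would sample from.
    policy_pdf = np.exp(-0.5 * (centres / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

    return {
        "empirical_std": residuals.std(),
        "policy_std": sigma,
        "max_abs_density_gap": np.abs(hist - policy_pdf).max(),
    }
```

An `empirical_std` well below `sigma` would indicate synthetic actions that are under-dispersed relative to the policy, i.e. lower effective entropy than the policy intends.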