Model Predictive Control with Differentiable World Models for Offline Reinforcement Learning

Rohan Deb, Stephen J. Wright, Arindam Banerjee

Abstract

Offline Reinforcement Learning (RL) aims to learn optimal policies from fixed offline datasets, without further interaction with the environment. Such methods train an offline policy (or value function) and apply it at inference time without further refinement. We introduce an inference-time adaptation framework inspired by model predictive control (MPC) that uses a pretrained policy together with a learned world model of state transitions and rewards. While existing world-model and diffusion-planning methods use learned dynamics to generate imagined trajectories during training, or to sample candidate plans at inference time, they do not use inference-time information to optimize the policy parameters on the fly. In contrast, we design a Differentiable World Model (DWM) pipeline that enables end-to-end gradient computation through imagined rollouts, so that the policy can be optimized at inference time in an MPC fashion. We evaluate our algorithm on D4RL continuous-control benchmarks (MuJoCo locomotion tasks and AntMaze) and show that exploiting inference-time information to optimize the policy parameters yields consistent gains over strong offline RL baselines.
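The inference-time procedure described above (and in Figure 1) can be sketched on a toy problem. This is a minimal sketch under stated assumptions, not the authors' implementation: the scalar linear dynamics, the linear policy a = psi * s, the quadratic reward, the state-only terminal value, the discount GAMMA, and all hyperparameters are hypothetical stand-ins for the learned $f_{\theta}$, $\pi_{\psi}$, $r_{\xi}$, and $Q_{\phi}$. For each imagined rollout the noise is held fixed and the gradient of the surrogate return is propagated forward through the rollout; psi is updated by gradient ascent averaged over several rollouts, and only the first action is executed. Carrying the adapted psi over to the next time step is also an assumption.

```python
import random

# Hypothetical toy problem (not the paper's models): scalar state s, linear
# policy a = psi * s, linear "world model" s' = A*s + B*a + SIGMA*eps,
# quadratic reward, and a state-only quadratic terminal value.
A, B, SIGMA = 1.1, 0.5, 0.05
C_ACT, V_TERM = 0.1, 5.0   # r(s, a) = -(s^2 + C_ACT*a^2); V(s_H) = -V_TERM*s_H^2
GAMMA = 0.99               # assumed discount on the surrogate return


def surrogate_and_grad(psi, s0, eps):
    """Surrogate return J(psi) of one imagined rollout with fixed noise eps,
    plus dJ/dpsi propagated forward along the rollout (forward-mode chain rule)."""
    s, ds = s0, 0.0        # ds tracks d s_h / d psi (s_0 does not depend on psi)
    J, dJ = 0.0, 0.0
    for h, e in enumerate(eps):
        a, da = psi * s, s + psi * ds          # action and its derivative
        J += GAMMA**h * -(s * s + C_ACT * a * a)
        dJ += GAMMA**h * -(2 * s * ds + 2 * C_ACT * a * da)
        s, ds = A * s + B * a + SIGMA * e, A * ds + B * da
    H = len(eps)
    J += GAMMA**H * -V_TERM * s * s            # terminal value at s_H
    dJ += GAMMA**H * -2 * V_TERM * s * ds
    return J, dJ


def mpc_act(psi, s_t, rng, horizon=5, n_rollouts=8, n_steps=10, lr=0.01):
    """Adapt psi by gradient ascent on the mean surrogate over imagined
    rollouts from s_t, then return (adapted psi, first action to execute)."""
    for _ in range(n_steps):
        grad = 0.0
        for _ in range(n_rollouts):
            eps = [rng.gauss(0.0, 1.0) for _ in range(horizon)]
            grad += surrogate_and_grad(psi, s_t, eps)[1] / n_rollouts
        psi += lr * grad
    return psi, psi * s_t


# Receding-horizon loop; here the "real" environment happens to equal the model.
rng = random.Random(0)
psi, s = 0.0, 1.0          # psi = 0.0 stands in for a pretrained policy
for t in range(20):
    psi, a = mpc_act(psi, s, rng)   # psi is carried over (an assumption)
    s = A * s + B * a + SIGMA * rng.gauss(0.0, 1.0)
```

The tangents here are derived by hand so the sketch runs with only the standard library; with neural-network models and a diffusion sampler, an autodiff framework would compute the same gradient through the unrolled rollout.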

Paper Structure (28 sections, 1 theorem, 52 equations, 3 figures, 5 tables, 1 algorithm)

Key Result

Theorem 4.1

Fix a time $t$, a horizon $H$, and a noise sequence $\varepsilon_{t:t+H-1}$, and let $\{(\tilde{s}_j,\tilde{a}_j)\}_{j=0}^{H}$ be defined by Eq. (mpc_rollout_recursion). Assume $\pi_{\psi}$ is differentiable in $\psi$ and its state input, and $f_{\theta}$, $r_{\xi}$, and $Q_{\phi}$ are differentiable i… Moreover, if $f_{\theta}$ is implemented by a reverse diffusion recursion as given by Eq. (g_h_def), th…
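The statement of Theorem 4.1 is truncated in this summary, so its exact form is not reproduced here. As a hedged illustration of the kind of gradient recursion it concerns, the sketch below computes the gradient of a finite-horizon surrogate return by a backward (adjoint) pass over one imagined rollout with a fixed noise sequence. It uses hypothetical scalar toy models (linear policy a = psi * s, linear dynamics, quadratic reward and terminal value, assumed discount GAMMA), not the paper's networks. Because the noise is fixed, each transition is a deterministic differentiable function of (s, a), which is what permits backpropagation; with a diffusion sampler, the single linear step would be replaced by the unrolled reverse-diffusion chain.

```python
# Hypothetical toy models (assumptions, not the paper's learned networks).
A, B, SIGMA, C_ACT, V_TERM, GAMMA = 1.1, 0.5, 0.05, 0.1, 5.0, 0.99


def rollout(psi, s0, eps):
    """Imagined rollout s_0..s_H, a_0..a_{H-1} with the noise sequence held fixed."""
    states, actions = [s0], []
    for e in eps:
        a = psi * states[-1]
        actions.append(a)
        states.append(A * states[-1] + B * a + SIGMA * e)
    return states, actions


def surrogate(psi, s0, eps):
    """Discounted predicted rewards plus discounted terminal value."""
    states, actions = rollout(psi, s0, eps)
    H = len(eps)
    J = sum(GAMMA**h * -(states[h] ** 2 + C_ACT * actions[h] ** 2) for h in range(H))
    return J + GAMMA**H * -V_TERM * states[H] ** 2


def grad_backward(psi, s0, eps):
    """Reverse-mode pass: lam holds dJ/ds_{h+1} while step h is processed."""
    states, actions = rollout(psi, s0, eps)
    H = len(eps)
    lam = GAMMA**H * -2 * V_TERM * states[H]    # dJ/ds_H from the terminal value
    g_psi = 0.0
    for h in reversed(range(H)):
        s, a = states[h], actions[h]
        r_s = GAMMA**h * -2 * s                 # partial of the reward w.r.t. s_h
        r_a = GAMMA**h * -2 * C_ACT * a         # partial of the reward w.r.t. a_h
        g_a = r_a + lam * B                     # total dJ/da_h (reward + via s_{h+1})
        g_psi += g_a * s                        # a_h = psi * s_h, so da_h/dpsi = s_h
        lam = r_s + lam * A + g_a * psi         # dJ/ds_h: direct, via s_{h+1}, via a_h
    return g_psi, lam                           # lam is now dJ/ds_0


g_psi, g_s0 = grad_backward(-0.4, 1.0, [0.3, -1.2, 0.5, 0.0, 0.7])
```

Only the psi-gradient is needed for the MPC update; the adjoint with respect to $s_0$ comes for free. Checking both against central finite differences is a cheap sanity test for any recursion of this kind.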

Figures (3)

  • Figure 1: Inference-time MPC with a diffusion world model. An offline dataset is used to train (i) a policy $\pi_{\psi}$ and a terminal critic $Q_{\phi}$, (ii) a reward model $r_{\xi}$, and (iii) a diffusion-based dynamics sampler $f_{\theta}$. At inference time, starting from the current state $s_t$, we unroll multiple imagined rollouts by alternating policy actions $\tilde{a}_h=\pi_{\psi}(\tilde{s}_h)$ and diffusion transitions $\tilde{s}_{h+1}=f_{\theta}(\tilde{s}_h,\tilde{a}_h,\varepsilon_{t+h})$, evaluate a finite-horizon surrogate return (predicted rewards plus terminal value), and backpropagate through the differentiable rollout to update $\psi$ before executing the first action in the real environment. Green arrows indicate the forward rollout; red dashed arrows indicate gradient flow.
  • Figure 2: One-step state prediction RMSE of diffusion models across training steps (20k–200k) on medium-replay datasets. RMSE decreases with training for all three environments (halfcheetah, hopper, walker2d), with shaded regions showing standard error over 1000 transitions.
  • Figure 3: Reward prediction RMSE of reward models across training steps (20k–200k) on medium-replay datasets. All environments show decreasing RMSE with training. Shaded regions indicate standard error over 1000 transitions.
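Figures 2 and 3 report prediction RMSE with standard errors over 1000 transitions. The exact aggregation is not given in this summary; one plausible reading, assumed here, is to compute an RMSE per held-out transition (over state dimensions, or over a single reward value) and report its mean and standard error across transitions. The helper name rmse_with_se is hypothetical.

```python
import math
import statistics


def rmse_with_se(preds, targets):
    """Mean per-transition RMSE and its standard error across transitions.

    preds, targets: equal-length lists of vectors (lists of floats), one pair
    per held-out transition; at least two transitions are needed for stdev.
    """
    per_transition = [
        math.sqrt(sum((p - y) ** 2 for p, y in zip(pred, target)) / len(target))
        for pred, target in zip(preds, targets)
    ]
    mean = statistics.fmean(per_transition)
    se = statistics.stdev(per_transition) / math.sqrt(len(per_transition))
    return mean, se


# Three 2-D transitions with per-transition RMSEs 1, 0, and 2.
mean, se = rmse_with_se([[1.0, 1.0], [0.0, 0.0], [2.0, 2.0]],
                        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
# mean == 1.0, se == 1 / sqrt(3), about 0.577
```

For a reward model the vectors would have length one; for a diffusion dynamics sampler, each prediction would be a sampled next state, so the reported RMSE also reflects sampling noise.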

Theorems & Definitions (2)

  • Theorem 4.1: Gradient recursion
  • Remark 4.2