Do Transformer World Models Give Better Policy Gradients?

Michel Ma, Tianwei Ni, Clement Gehring, Pierluca D'Oro, Pierre-Luc Bacon

TL;DR

The paper tackles long-horizon policy optimization with differentiable world models, showing that naive transformer-based history models can create circuitous gradient paths that destabilize learning. It introduces Actions World Models (AWMs), which condition only on the initial state and past actions, so that gradients flow directly from rewards to actions. It also provides theoretical bounds showing that gradient dynamics depend on the underlying architecture (e.g., $O(\eta^H)$ for RNNs vs. $O(H^3)$ for self-attention). Empirically, transformer AWMs yield smoother, more navigable policy optimization landscapes and achieve better long-horizon performance on Myriad benchmarks than both model-based and model-free baselines, in some cases even outperforming the true differentiable simulator. The work links neural-network architecture design to policy-gradient updates, suggesting a scalable path to better long-horizon model-based RL through architecture-aware gradient propagation.
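
The contrast between a step-by-step unrolled model and an AWM can be written out with the standard chain rule. The symbols $f$ (a one-step Markovian model) and $g$ (an AWM that maps the initial state and action sequence to a predicted state) are notation assumed here for illustration and are not taken from the paper.

$$
\hat{s}_1 = s_1, \quad \hat{s}_{t+1} = f(\hat{s}_t, a_t)
\;\Longrightarrow\;
\frac{\partial \hat{s}_H}{\partial a_1}
= \left( \prod_{t=2}^{H-1} \frac{\partial f(\hat{s}_t, a_t)}{\partial \hat{s}_t} \right)
\frac{\partial f(\hat{s}_1, a_1)}{\partial a_1}
$$

$$
\hat{s}_H = g(s_1, a_{1:H-1})
\;\Longrightarrow\;
\frac{\partial \hat{s}_H}{\partial a_1}
= \frac{\partial g(s_1, a_{1:H-1})}{\partial a_1}
$$

In the first case the factors are ordered from $t = H-1$ down to $t = 2$, so the gradient passes through $H-2$ state Jacobians; if each has norm at most $\eta$, the product can scale like $\eta^{H-2}$. In the second case the gradient is a single Jacobian of $g$, whose size is set by the architecture of $g$ rather than by a product over time steps.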

Abstract

A natural approach for reinforcement learning is to predict future rewards by unrolling a neural network world model, and to backpropagate through the resulting computational graph to learn a policy. However, this method often becomes impractical for long horizons since typical world models induce hard-to-optimize loss landscapes. Transformers are known to efficiently propagate gradients over long horizons: could they be the solution to this problem? Surprisingly, we show that commonly-used transformer world models produce circuitous gradient paths, which can be detrimental to long-range policy gradients. To tackle this challenge, we propose a class of world models called Actions World Models (AWMs), designed to provide more direct routes for gradient propagation. We integrate such AWMs into a policy gradient framework that underscores the relationship between network architectures and the policy gradient updates they inherently represent. We demonstrate that AWMs can generate optimization landscapes that are easier to navigate even when compared to those from the simulator itself. This property allows transformer AWMs to produce better policies than competitive baselines in realistic long-horizon tasks.
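
A minimal sketch of the idea, assuming a scalar toy setting rather than the paper's models: `unroll_markov` backpropagates through a hand-written linear Markovian model, while `awm_predict` stands in for an actions world model that maps the initial state and all actions to the final state in one call. Both function names, the linear dynamics, and the AWM weights are hypothetical choices made for illustration; the weights are not fitted to anything.

```python
def unroll_markov(s1, actions, eta):
    """Unroll the toy Markovian model s_{t+1} = eta * s_t + a_t.

    Returns the final state and d(final state)/d(a_1), accumulated with the
    chain rule as a product of per-step Jacobians.
    """
    s = s1
    grad_a1 = 0.0
    for t, a in enumerate(actions):
        if t == 0:
            grad_a1 = 1.0      # d s_2 / d a_1 for this linear model
        else:
            grad_a1 *= eta     # one more state Jacobian on the path back to a_1
        s = eta * s + a
    return s, grad_a1


def awm_predict(s1, actions, w_state, w_actions):
    """Toy linear actions world model (hypothetical): one call, no unrolling.

    The gradient with respect to a_1 is a single Jacobian entry, w_actions[0],
    instead of a product over time steps.
    """
    s_final = w_state * s1 + sum(w * a for w, a in zip(w_actions, actions))
    return s_final, w_actions[0]


if __name__ == "__main__":
    eta = 1.5  # assumed per-step Jacobian magnitude, greater than 1
    for horizon in (5, 20, 50):
        actions = [0.1] * (horizon - 1)
        _, g_markov = unroll_markov(0.0, actions, eta)
        _, g_awm = awm_predict(0.0, actions, 1.0, [0.5] * (horizon - 1))
        print(f"H={horizon}: markov grad={g_markov:.3e}, awm grad={g_awm:.3e}")
```

In this toy, the unrolled gradient grows as `eta ** (H - 2)`, while the AWM-style gradient is whatever the direct map assigns to the first action. Whether a learned AWM keeps that gradient well-behaved depends on its architecture, which is the question the paper studies.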

Paper Structure

This paper contains 26 sections, 10 theorems, 40 equations, 11 figures, 2 tables, and 1 algorithm.

Key Result

Theorem 1

Let the gradient norm of $h$ with respect to its inputs be bounded by $L_a$ and $L_s$: $\|\frac{\partial h(\hat{s}_{1:t}, a_{1:t})}{\partial a_k}\| \leq L_a$ and $\| \frac{\partial h(\hat{s}_{1:t}, a_{1:t})}{\partial \hat{s}_i}\| \leq L_s$ for all $s_{1:t}, a_{1:t}, k, i$. Let $r$ be the $L_r$-Lipsc

Figures (11)

  • Figure 1: Diagram illustrating gradient flows through different world model types from states to actions. Circuitous (longer than necessary) gradient paths go through connections highlighted in red. An Actions World Model has no circuitous gradient paths, allowing gradients to directly flow from states to actions through a single application of a world model.
  • Figure 2: Transformer AWMs outperform all BPO baselines in chaotic environments. Final performances of BPO with different world models on the double-pendulum environment (10 seeds $\pm$ std).
  • Figure 3: AWMs ignore non-differentiable points in the state space. (a) After the block is pushed with some initial action, it bounces off the wall, instantaneously reversing its velocity. (b) Visualization of the point of non-differentiability in the state space. (c) Learning a Markovian model or a HWM causes catastrophic compounding errors, but an AWM can still accurately model the final reward when varying the initial action. Learned dynamics are trained offline on a dataset collected using random actions.
  • Figure 4: Transformer AWMs smooth out chaotic dynamics. (a) A double-pendulum environment where an initial position must be chosen in order to achieve some pre-determined goal state after $H$ steps. Different transition models are learned on a dataset of random trajectories. (b) The mean gradient norm of the final state with respect to the initial action for each model is computed over 50 different random actions for different horizons. (c) Final return according to different models with respect to different initial actions for $H=100$.
  • Figure 5: Policy optimization with transformer AWMs gives better policies for long horizons. (a) Final performance of BPO with different world models on Myriad (10 seeds $\pm$ 95% C.I.). (b) Learning curves of BPO through a transformer AWM, a SAC agent, and an Online-DT agent on 20 and 100 length horizons (10 seeds $\pm$ 95% C.I.).
  • ...and 6 more figures

Theorems & Definitions (18)

  • Theorem 1
  • Proposition 1
  • Theorem 2
  • Corollary 2.1
  • Corollary 2.2
  • Theorem 2
  • proof
  • Proposition 1
  • proof
  • Theorem 2
  • ...and 8 more