Do Transformer World Models Give Better Policy Gradients?
Michel Ma, Tianwei Ni, Clement Gehring, Pierluca D'Oro, Pierre-Luc Bacon
TL;DR
The paper tackles long-horizon policy optimization with differentiable world models and shows that naive transformer-based history models can create circuitous gradient paths that destabilize learning. It introduces Actions World Models (AWMs), which condition only on the initial state and past actions, so that gradients flow directly from predicted rewards to the actions that produced them. It also provides theoretical bounds showing that the way policy gradients scale with the horizon $H$ depends on the underlying architecture (e.g., $O(\eta^H)$ for RNNs vs. $O(H^3)$ for self-attention). Empirically, transformer AWMs yield smoother, easier-to-navigate policy optimization landscapes and outperform both model-based and model-free baselines on long-horizon Myriad tasks, in some cases even outperforming optimization through the true differentiable simulator. The work links network architecture design to the policy-gradient updates it induces, suggesting a scalable path toward long-horizon model-based RL through architecture-aware gradient propagation.
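The contrast between a history-style model and an AWM can be illustrated with a toy JAX sketch. Both models below are hypothetical scalar stand-ins, not the paper's architectures: `recurrent_rollout` predicts each state from the previous one, while `actions_rollout` mimics an AWM by computing each state from the initial state and the action prefix alone (a uniform average stands in for attention over past actions). The names and the constants `ETA` and `H` are illustrative assumptions. The sketch only shows the qualitative effect, a gradient to the first action that decays as `ETA ** (H - 1)` along the recurrent chain versus a single direct edge in the AWM-style model; it does not reproduce the paper's bounds.

```python
import jax
import jax.numpy as jnp

# Illustrative constants (hypothetical, not taken from the paper).
ETA = 0.8  # recurrent weight: |ETA| < 1 shrinks early-action gradients, |ETA| > 1 blows them up
H = 50     # horizon


def recurrent_rollout(s0, actions):
    """History-style toy model: s_{t+1} = ETA * s_t + a_t.

    The last state depends on a_0 only through H - 1 chained steps,
    so d s_H / d a_0 = ETA ** (H - 1).
    """
    s0 = jnp.asarray(s0, dtype=actions.dtype)

    def step(s, a):
        s_next = ETA * s + a
        return s_next, s_next

    _, states = jax.lax.scan(step, s0, actions)
    return states  # states[t] is s_{t+1}


def actions_rollout(s0, actions):
    """AWM-style toy model: s_{t+1} = s0 + mean(a_0, ..., a_t).

    Each state is a function of s0 and the action prefix only, so every
    action reaches the last state through one edge: d s_H / d a_0 = 1 / H.
    """
    counts = jnp.arange(1, actions.shape[0] + 1, dtype=actions.dtype)
    return s0 + jnp.cumsum(actions) / counts  # element t is s_{t+1}


def last_state_grad(rollout, s0, actions):
    # Gradient of the final predicted state with respect to every action.
    return jax.grad(lambda a: rollout(s0, a)[-1])(actions)


actions = jnp.zeros(H)
g_rec = last_state_grad(recurrent_rollout, 0.0, actions)
g_awm = last_state_grad(actions_rollout, 0.0, actions)
print(g_rec[0], g_awm[0])  # gradient reaching the first action under each model
```

Increasing `H`, or moving `ETA` further from 1, makes the recurrent gradient to early actions vanish or explode exponentially, while in this toy the AWM-style gradient shrinks only as `1 / H`.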
Abstract
A natural approach for reinforcement learning is to predict future rewards by unrolling a neural network world model, and to backpropagate through the resulting computational graph to learn a policy. However, this method often becomes impractical for long horizons since typical world models induce hard-to-optimize loss landscapes. Transformers are known to efficiently propagate gradients over long horizons: could they be the solution to this problem? Surprisingly, we show that commonly used transformer world models produce circuitous gradient paths, which can be detrimental to long-range policy gradients. To tackle this challenge, we propose a class of world models called Actions World Models (AWMs), designed to provide more direct routes for gradient propagation. We integrate such AWMs into a policy gradient framework that underscores the relationship between network architectures and the policy gradient updates they inherently represent. We demonstrate that AWMs can generate optimization landscapes that are easier to navigate even when compared to those from the simulator itself. This property allows transformer AWMs to produce better policies than competitive baselines in realistic long-horizon tasks.
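The basic recipe in the abstract's first sentence, unrolling a differentiable world model and backpropagating the predicted rewards into the policy parameters, can be sketched in JAX as follows. This is a minimal sketch under stated assumptions: `world_model` is a hypothetical fixed scalar model (a real setup would learn it), `policy` is an illustrative linear-feedback policy, and `H`, the target of 1.0, and the learning rate are arbitrary. It is plain backpropagation through time, not the paper's AWM-based estimator or training setup.

```python
import jax
import jax.numpy as jnp

H = 20  # planning horizon (illustrative)


def world_model(s, a):
    """Hypothetical differentiable model: returns (next_state, reward).

    In practice this would be a learned network; here it is a fixed toy so the
    example runs on its own. The reward penalizes distance to a target of 1.0.
    """
    s_next = s + 0.1 * a
    reward = -(s_next - 1.0) ** 2
    return s_next, reward


def policy(params, s):
    # Illustrative linear-feedback policy a = w * s + b.
    w, b = params
    return w * s + b


def predicted_return(params, s0):
    # Unroll policy and model for H steps and sum the predicted rewards.
    def step(s, _):
        s_next, r = world_model(s, policy(params, s))
        return s_next, r

    _, rewards = jax.lax.scan(step, s0, None, length=H)
    return jnp.sum(rewards)


@jax.jit
def update(params, s0, lr=0.01):
    # Backpropagate through the whole unroll, then take a gradient-ascent step.
    grads = jax.grad(predicted_return)(params, s0)
    return jax.tree_util.tree_map(lambda p, g: p + lr * g, params, grads)


params = (jnp.array(0.0), jnp.array(0.0))
s0 = jnp.array(0.0)
before = predicted_return(params, s0)
for _ in range(200):
    params = update(params, s0)
after = predicted_return(params, s0)
print(float(before), float(after))
```

Swapping `world_model` for a learned recurrent or transformer model changes the computational graph that `jax.grad` differentiates through; the abstract argues that this choice of graph is what matters at long horizons.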
