Policy-shaped prediction: avoiding distractions in model-based reinforcement learning

Miles Hutson, Isaac Kauvar, Nick Haber

TL;DR

Policy-Shaped Prediction (PSP) reduces the sensitivity of image-based model-based RL to distractions through three mechanisms: it shapes the world-model loss with policy gradients, aggregates salience at the object level using a pretrained segmentation model, and suppresses self-generated distractions with an adversarial action-prediction head. Concretely, PSP weights each pixel's reconstruction error by the gradient of the policy with respect to that pixel, $\mathcal{L}_{image}(\phi) = \sum_i \frac{\partial a}{\partial x_i} (\hat{x}_i - x_i)^2$, and aggregates salience over segmentation masks using $W_i$ and $W_i''$ with $\alpha=0.9$. It also adds an action-prediction head that outputs $\hat{a}_{t-1}$ and is trained with $\mathcal{L}_{AdvHead}$; the head's gradients are subtracted from the world-model updates. Empirically, PSP yields about a 2x improvement in robustness to challenging, learnable distractions on Reafferent DMC and Distracting Control while preserving performance on non-distracting benchmarks. The work connects gradient-based explainability, segmentation priors, and biologically inspired mechanisms to steer world-model capacity toward policy-relevant dynamics, with implications for real-world robotics and intelligent agents.
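The policy-weighted reconstruction loss and the segment-level averaging described in the TL;DR can be illustrated with a minimal sketch. It is not the authors' implementation: it assumes a toy linear policy $a = \sum_i w_i x_i$ (so that $\partial a / \partial x_i = w_i$), takes the absolute value of the gradient so that weights are non-negative, and uses hand-written segment labels in place of a pretrained segmentation model. All function names are hypothetical, and the $\alpha$-dependent step is omitted because its definition is not given here.

```python
from collections import defaultdict


def policy_salience(policy_weights):
    """Per-pixel salience |da/dx_i| for a toy linear policy a = sum_i w_i * x_i.

    Assumption: the absolute value is used so that weights are non-negative.
    """
    return [abs(w) for w in policy_weights]


def average_within_segments(salience, segment_ids):
    """Replace each pixel's salience with the mean salience of its segment."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    for s, seg in zip(salience, segment_ids):
        totals[seg] += s
        counts[seg] += 1
    return [totals[seg] / counts[seg] for seg in segment_ids]


def weighted_image_loss(weights, x_true, x_pred):
    """Sum over pixels of weight_i * (x_pred_i - x_true_i)^2."""
    return sum(w * (p - t) ** 2 for w, t, p in zip(weights, x_true, x_pred))


# Toy 6-pixel "image": pixels 0-2 are the agent (segment 0),
# pixels 3-5 are the background (segment 1).
policy_weights = [0.9, 0.0, 0.6, 0.0, 0.0, 0.0]  # policy only looks at the agent
segment_ids = [0, 0, 0, 1, 1, 1]                 # hypothetical segmentation output
x_true = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
x_pred = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]          # every pixel is wrong by 1

weights = average_within_segments(policy_salience(policy_weights), segment_ids)
loss = weighted_image_loss(weights, x_true, x_pred)
```

In this toy case every pixel is reconstructed equally badly, but only the agent's pixels contribute to the loss because the background segment has zero mean salience. Averaging within a segment also spreads weight to agent pixels that the policy gradient happens to ignore, such as pixel 1.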

Abstract

Model-based reinforcement learning (MBRL) is a promising route to sample-efficient policy optimization. However, a known vulnerability of reconstruction-based MBRL arises in scenarios in which detailed aspects of the world are highly predictable but irrelevant to learning a good policy. Such scenarios can lead the model to exhaust its capacity on meaningless content, at the cost of neglecting important environment dynamics. While existing approaches attempt to solve this problem, we highlight its continuing impact on leading MBRL methods, including DreamerV3 and DreamerPro, with a novel environment where background distractions are intricate, predictable, and useless for planning future actions. To address this challenge, we develop a method for focusing the capacity of the world model through the synergy of a pretrained segmentation model, a task-aware reconstruction loss, and adversarial learning. Our method outperforms a variety of other approaches designed to reduce the impact of distractors, and is an advance towards robust model-based reinforcement learning.

Paper Structure

This paper contains 23 sections, 4 equations, 17 figures, 4 tables.

Figures (17)

  • Figure 1: Policy-Shaped Prediction in an environment with challenging distractions. (left) Training of an otherwise-unaltered DreamerV3 agent is modified in two ways: 1) a head is added to predict the previous action based on the image encoding, and the gradient of the head is subtracted from the gradient of the image encoder, and 2) the loss is scaled pixelwise by a policy-shaped loss weight. (right) The loss weight uses the gradient of the policy with respect to the input pixels. The image is segmented, and the pixel weights are averaged within each segmented object. Dashed lines signify gradient flow.
  • Figure 2: Schematic of the Reafferent DeepMind Control environment. The distracting background is entirely predictable based on the agent's previous action and the elapsed time in the episode.
  • Figure 3: Training curve comparisons on Reafferent DeepMind Control. Mean $\pm$ std. err.
  • Figure 4: Reconstructed image comparison, PSP vs. DreamerV3 on Reafferent Cheetah Run, same episode and time point. True, reconstructed, difference (true - recon.). DreamerV3 accurately reproduces the background but not the cheetah.
  • Figure 5: Example salience maps (policy-shaped loss weights) highlight the agent.
  • ...and 12 more figures
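
The gradient subtraction described in the Figure 1 caption (an action-prediction head whose gradient is subtracted from the image encoder's gradient) can be sketched on a scalar toy model. This is an illustrative assumption rather than the paper's architecture: the encoder, head, and both losses are hypothetical one-parameter stand-ins with hand-derived gradients, and no scaling coefficient is applied to the subtracted term.

```python
def adversarial_gradients(w_enc, w_head, x, prev_action, recon_target):
    """Gradients for a hypothetical scalar toy model.

    encoder:  z = w_enc * x
    head:     a_hat = w_head * z            (predicts the previous action)
    L_wm   = (z - recon_target) ** 2        (stand-in for the world-model loss)
    L_head = (a_hat - prev_action) ** 2     (stand-in for the adversarial head loss)
    """
    z = w_enc * x
    a_hat = w_head * z

    # Analytic partial derivatives of the two losses.
    d_wm_d_enc = 2.0 * (z - recon_target) * x
    d_head_d_enc = 2.0 * (a_hat - prev_action) * w_head * x
    d_head_d_head = 2.0 * (a_hat - prev_action) * z

    # Encoder: descend the world-model loss, but subtract the head's gradient
    # so the encoding becomes less informative about the previous action.
    enc_grad = d_wm_d_enc - d_head_d_enc
    # Head: ordinary descent on its own prediction loss.
    head_grad = d_head_d_head
    return enc_grad, head_grad


enc_grad, head_grad = adversarial_gradients(
    w_enc=1.0, w_head=1.0, x=1.0, prev_action=0.0, recon_target=1.0
)

# One plain gradient-descent step with a hypothetical learning rate.
lr = 0.1
new_w_enc = 1.0 - lr * enc_grad
new_w_head = 1.0 - lr * head_grad
```

In this example the world-model loss is already zero, so the encoder update is driven entirely by the subtracted head gradient: the encoder weight moves in the direction that increases the head's error, while the head itself moves to reduce it.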