Objects matter: object-centric world models improve reinforcement learning in visually complex environments

Weipu Zhang, Adam Jelley, Trevor McInroe, Amos Storkey

TL;DR

This paper tackles the low sample efficiency of deep reinforcement learning in visually complex environments by proposing OC-STORM, an object-centric MBRL pipeline. It uses a frozen vision model (Cutie) to extract compact object features, combines these with raw observations via a spatial-temporal transformer, and trains policies on imagined trajectories generated by a world model with a categorical latent representation. OC-STORM improves sample efficiency over STORM on Atari 100k (18 of 26 tasks) and converges faster and performs better on Hollow Knight bosses, highlighting practical gains in visually rich settings. The work also analyzes the completeness and usefulness of object representations and discusses limitations related to duplicate objects and background information, suggesting that future foundation-model integration could improve robustness.

Abstract

Deep reinforcement learning has achieved remarkable success in learning control policies from pixels across a wide range of tasks, yet its application remains hindered by low sample efficiency, requiring significantly more environment interactions than humans to reach comparable performance. Model-based reinforcement learning (MBRL) offers a solution by leveraging learnt world models to generate simulated experience, thereby improving sample efficiency. However, in visually complex environments, small or dynamic elements can be critical for decision-making. Yet, traditional MBRL methods in pixel-based environments typically rely on auto-encoding with an $L_2$ loss, which is dominated by large areas and often fails to capture decision-relevant details. To address these limitations, we propose an object-centric MBRL pipeline, which integrates recent advances in computer vision to allow agents to focus on key decision-related elements. Our approach consists of four main steps: (1) annotating key objects related to rewards and goals with segmentation masks, (2) extracting object features using a pre-trained, frozen foundation vision model, (3) incorporating these object features with the raw observations to predict environmental dynamics, and (4) training the policy using imagined trajectories generated by this object-centric world model. Building on the efficient MBRL algorithm STORM, we call this pipeline OC-STORM. We demonstrate OC-STORM's practical value in overcoming the limitations of conventional MBRL approaches on both Atari games and the visually complex game Hollow Knight.
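The abstract's point that an $L_2$ reconstruction loss is dominated by large areas can be illustrated with a toy calculation. The sketch below is not from the paper: the 64×64 frame size, the 3×3 object, and the 0.05 background offset are all illustrative assumptions. It compares a reconstruction that drops a small object entirely against one that keeps the object but is slightly off everywhere.

```python
import numpy as np

def mse(target, recon):
    """Mean squared error averaged over all pixels."""
    return float(np.mean((target - recon) ** 2))

# Hypothetical 64x64 grayscale frame: black background plus one 3x3 bright object.
H, W = 64, 64
frame = np.zeros((H, W), dtype=np.float32)
frame[30:33, 30:33] = 1.0

# Reconstruction A: background is perfect, but the object is missing entirely.
recon_missing_object = np.zeros_like(frame)

# Reconstruction B: object is present, but every pixel is off by a small offset.
recon_background_error = frame + 0.05

loss_missing = mse(frame, recon_missing_object)       # 9 / 4096
loss_background = mse(frame, recon_background_error)  # about 0.05 ** 2

print(f"object missing:   {loss_missing:.6f}")
print(f"background error: {loss_background:.6f}")
```

Under these assumed numbers, dropping the object costs less than a barely visible error spread over the whole frame, so a pixel-wise loss gives the model little pressure to reconstruct small, decision-relevant elements.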

Paper Structure

This paper contains 44 sections, 14 equations, 18 figures, 4 tables.

Figures (18)

  • Figure 1: A simplified illustration of the object transformer in Cutie. For technical details, please refer to the original paper (cheng_putting_2023). The tuples in square brackets represent the shapes of the corresponding tensors.
  • Figure 2: The model structure of our proposed OC-STORM. The tuples in square brackets represent the shapes of the corresponding tensors, where $L$ denotes the batch length or sequence length, $K$ is the number of objects, and $H$ and $W$ are the image height and width, respectively. The object module constitutes the proposed object-centric component, while the visual module processes resized raw observations. $K^*$ is explained in Section \ref{sec:spatial-temporal}. The trainable token and positional embeddings are broadcast to match the shapes of the corresponding tensors. The reward logit is 255-dimensional and used for the symlog two-hot loss (hafner_mastering_2023).
  • Figure 3: Observation reconstructions on Atari Boxing with two object feature vectors as inputs. The object mask row is generated using Cutie, which highlights the relevant objects.
  • Figure 4: Training episode returns for different input module configurations. The solid line shows the mean over 5 seeds and the semi-transparent band shows the standard deviation. "Vector" and "visual" correspond to the object module and visual module, respectively, as depicted in Figure 2.
  • Figure 5: Sample ground truth observations from the Hollow Knight boss Hornet Protector, the reconstruction results of STORM, and the probabilities of the $32 \times 32$ latent distribution. In this instance, STORM was trained on 200k samples. The key characters are missing in the reconstructions.
  • ...and 13 more figures
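
The Figure 2 caption mentions a 255-dimensional reward logit trained with a symlog two-hot loss. As a rough illustration of what such a target looks like, the sketch below encodes a scalar reward as a two-hot vector over 255 bins. The bin range of [-20, 20] in symlog space and the helper names (`two_hot`, `reward_target`, `decode`) are assumptions made for this example; the captions do not specify them, and the paper's implementation may differ.

```python
import numpy as np

NUM_BINS = 255
# Assumption: bins are evenly spaced in symlog space over [-20, 20].
BINS = np.linspace(-20.0, 20.0, NUM_BINS)

def symlog(x):
    """Compress large magnitudes symmetrically around zero."""
    return np.sign(x) * np.log1p(np.abs(x))

def symexp(x):
    """Inverse of symlog."""
    return np.sign(x) * np.expm1(np.abs(x))

def two_hot(value):
    """Spread a scalar over its two neighbouring bins by linear interpolation."""
    value = np.clip(value, BINS[0], BINS[-1])
    upper = min(int(np.searchsorted(BINS, value, side="right")), NUM_BINS - 1)
    lower = upper - 1
    weight_upper = (value - BINS[lower]) / (BINS[upper] - BINS[lower])
    target = np.zeros(NUM_BINS)
    target[lower] = 1.0 - weight_upper
    target[upper] = weight_upper
    return target

def reward_target(reward):
    """Two-hot training target for a raw scalar reward."""
    return two_hot(symlog(reward))

def decode(probs):
    """Recover a scalar from a distribution over bins (expected bin value, then symexp)."""
    return float(symexp(probs @ BINS))

target = reward_target(3.0)
print(target.shape, np.count_nonzero(target), round(decode(target), 6))
```

In training, a cross-entropy loss between the softmax of the 255 reward logits and this target would play the role of the symlog two-hot loss; decoding the target itself recovers the original reward up to floating-point error.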