Object-Centric World Models for Causality-Aware Reinforcement Learning

Yosuke Nishimoto, Takashi Matsubara

TL;DR

STICA addresses the challenge of learning effective policies in high-dimensional, multi-object environments by combining an object-centric world model with Transformer dynamics and causality-aware policy and value networks. It decomposes each observation into object slots (excluding the static background), uses a Transformer-based dynamics model to imagine future trajectories, and uses a causal graph to steer attention toward task-relevant objects. Ablation and visualization analyses indicate that causal attention and background removal are key to focusing on pertinent objects, and STICA yields substantial gains on the Safety Gym and OCVRL benchmarks. The work advances model-based RL with interpretable, object-centered representations and causality-guided decision-making, with potential applications in robotics and planning under uncertainty.
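To make the imagination step concrete, here is a minimal Python sketch of rolling a world model forward and scoring the imagined trajectory. It assumes, following the Figure 1 caption, that the dynamics model maps object slots and an action to the next slots, a reward $\hat{r}_t$, and a discount $\hat{\gamma}_t$. The functions `imagine_return`, `toy_dynamics`, and `toy_policy` are hypothetical stand-ins, not STICA's code, and a plain discounted sum is used here in place of whatever return estimator the paper actually trains with.

```python
from typing import Callable, List, Tuple

# Hypothetical types: each slot is one object latent z_t^k; actions are scalars here.
Slots = List[List[float]]
Dynamics = Callable[[Slots, float], Tuple[Slots, float, float]]


def imagine_return(
    slots: Slots,
    policy: Callable[[Slots], float],
    dynamics: Dynamics,
    horizon: int,
) -> float:
    """Roll the model forward in imagination and accumulate
    sum_t (prod_{i<t} gamma_hat_i) * r_hat_t over `horizon` steps."""
    ret, weight = 0.0, 1.0
    for _ in range(horizon):
        action = policy(slots)                             # act on imagined slots
        slots, reward, discount = dynamics(slots, action)  # predict z_{t+1}, r_t, gamma_t
        ret += weight * reward
        weight *= discount                                 # down-weight later rewards
    return ret


def toy_dynamics(slots: Slots, action: float) -> Tuple[Slots, float, float]:
    # Hypothetical stand-in for the learned model: shift every slot by the
    # action and return a constant reward and discount.
    return [[x + action for x in z] for z in slots], 1.0, 0.5


def toy_policy(slots: Slots) -> float:
    return 0.1  # hypothetical constant policy


print(imagine_return([[0.0, 0.0], [1.0, 1.0]], toy_policy, toy_dynamics, horizon=3))  # 1.75
```

Treating the predicted discount as a multiplicative continuation weight lets the model signal episode termination by predicting $\hat{\gamma}_t$ near zero; this is a common choice in model-based RL and is an assumption in this sketch.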

Abstract

World models have been developed to support sample-efficient deep reinforcement learning agents. However, it remains challenging for world models to accurately replicate environments that are high-dimensional, non-stationary, and composed of multiple objects with rich interactions, since most world models learn holistic representations of all environmental components. By contrast, humans perceive the environment by decomposing it into discrete objects, facilitating efficient decision-making. Motivated by this insight, we propose Slot Transformer Imagination with CAusality-aware reinforcement learning (STICA), a unified framework in which object-centric Transformers serve as both the world model and the causality-aware policy and value networks. STICA represents each observation as a set of object-centric tokens, together with tokens for the agent action and the resulting reward, enabling the world model to predict token-level dynamics and interactions. The policy and value networks then estimate token-level cause-effect relations and use them in the attention layers, yielding causality-guided decision-making. Experiments on object-rich benchmarks demonstrate that STICA consistently outperforms state-of-the-art agents in both sample efficiency and final performance.
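The abstract says each observation becomes a set of object-centric tokens alongside action and reward tokens. The sketch below shows one hypothetical layout for such a sequence plus a block-causal attention mask, assuming the inputs listed in the Figure 1 caption (latent states $z_{1:t}$, actions $a_{1:t}$, rewards $r_{1:t-1}$). The token order within a step, the names `token_layout` and `block_causal_mask`, and the masking rule are illustrative assumptions, not details confirmed by the paper.

```python
from typing import List, Tuple

# (kind, time step t, object index k); k = 0 for non-slot tokens. Hypothetical layout.
Token = Tuple[str, int, int]


def token_layout(T: int, n: int) -> List[Token]:
    """Per step t: a reward token carrying r_{t-1} (only for t > 1),
    n slot tokens z_t^1..z_t^n, and an action token a_t."""
    tokens: List[Token] = []
    for t in range(1, T + 1):
        if t > 1:
            tokens.append(("reward", t, 0))  # r_{t-1}, since rewards r_{1:t-1} are inputs
        tokens.extend(("slot", t, k) for k in range(1, n + 1))
        tokens.append(("action", t, 0))
    return tokens


def block_causal_mask(tokens: List[Token]) -> List[List[bool]]:
    """mask[i][j] is True when token i may attend to token j:
    every token of the same or an earlier step, never a later one."""
    return [[tj[1] <= ti[1] for tj in tokens] for ti in tokens]


toks = token_layout(T=2, n=2)
print(len(toks))  # 7
mask = block_causal_mask(toks)
print(sum(mask[0]))  # 3: the first slot token sees only the step-1 tokens
```

Letting tokens within a step attend to one another while blocking later steps allows object slots to exchange information about their interactions without leaking future observations.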

Paper Structure

This paper contains 35 sections, 8 equations, 21 figures, 4 tables.

Figures (21)

  • Figure 1: The architecture and object-centric representations of STICA. (a) Object-centric world model. A slot-based autoencoder extracts object-centric latent states $(z_{t}^1,\dots,z_{t}^n)$ from the observation $o_{t}$ at time $t$ ($1\le t \le T$), excluding the static background information $z_{BG}$. A Transformer-based dynamics model computes hidden states $(h_{1:t}^1, \dots, h_{1:t}^n)$ and $h_{1:t}'$ from the latent states $(z_{1:t}^1, \dots, z_{1:t}^n)$, actions $a_{1:t}$, and rewards $r_{1:t-1}$, followed by multilayer perceptrons (MLPs) that predict the next latent states $(\hat{z}_{t+1}^1, \dots, \hat{z}_{t+1}^n)$, the reward $\hat{r}_t$, and the discount factor $\hat{\gamma}_t$. (b) Examples of object-centric representations for a Safety Gym benchmark task: the observation $o_t$, its reconstruction $\hat{o}_t$, the reconstructions from the extracted object-centric latent states $(z_t^1,\dots,z_t^5)$, and the reconstruction from the static background information $z_{BG}$. (c) Causal policy and value networks. They estimate causal relationships from the latent states to the action or value based on a causal graph $G$ and causality scores $p_t^k$, and adjust the attention weights within the Transformers accordingly, enabling causality-aware decision-making. The latent states ($z_t^1$, $z_t^2$, and $z_t^5$) of goal-related objects or obstacles are expected to exert stronger causal influence on the target token ($a'_t$ or $v'_t$), whereas the latent states ($z_t^3$ and $z_t^4$) of objects irrelevant to task completion exert weaker causal influence.
  • Figure 2: (Left) Eight tasks from the Safety Gym benchmark. Top panels show the first-person views used for training; bottom panels show fixed-view images for reference only, which are not used for training. (Right) Three tasks from the OCVRL benchmark.
  • Figure 3: Training curves for the Safety Gym benchmark.
  • Figure 4: Training curves for the OCVRL benchmark.
  • Figure 5: Ablation studies on the PointButton1 and CarButton1 tasks from the Safety Gym benchmark.
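The Figure 1 caption says the policy and value networks adjust Transformer attention weights using causality scores $p_t^k$, but it does not give the exact rule. One simple possibility, sketched below under that assumption, multiplies the softmax attention of the target token ($a'_t$ or $v'_t$) over the slot tokens by $p_t^k$ and renormalizes, which is equivalent to adding $\log p_t^k$ to the attention logits. The functions `causal_attention_weights` and `attend` and the all-zero fallback are hypothetical.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def causal_attention_weights(
    query: List[float], keys: List[List[float]], causal_scores: List[float]
) -> List[float]:
    """Scaled dot-product attention of one target token over slot tokens,
    reweighted by hypothetical causality scores p_t^k in [0, 1]."""
    d = len(query)
    logits = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    base = softmax(logits)
    weighted = [w * p for w, p in zip(base, causal_scores)]
    total = sum(weighted)
    if total == 0.0:
        return base  # fall back to plain attention if every score is zero
    return [w / total for w in weighted]


def attend(weights: List[float], values: List[List[float]]) -> List[float]:
    """Weighted sum of value vectors."""
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]


# Hypothetical scores for z_t^1..z_t^5: high for goal/obstacle slots, low for the rest.
scores = [0.9, 0.9, 0.1, 0.1, 0.9]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5], [0.0, 0.0]]
w = causal_attention_weights([1.0, 0.0], keys, scores)
print([round(x, 3) for x in w])
```

Because the scores scale attention multiplicatively, a slot with $p_t^k$ near zero contributes almost nothing to the action or value token, regardless of how similar its key is to the query.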