Object-Centric World Models for Causality-Aware Reinforcement Learning
Yosuke Nishimoto, Takashi Matsubara
TL;DR
STICA addresses the challenge of learning effective policies in high-dimensional, multi-object environments by combining an object-centric world model with Transformer dynamics and causality-aware policy and value networks. It decomposes observations into object slots, uses a Transformer-based dynamics model to imagine future trajectories, and employs a causality graph to modulate attention toward task-relevant objects. Ablation and visualization analyses indicate that causal attention and background removal are key to focusing on pertinent objects, yielding substantial gains on Safety Gym and OC-VRL benchmarks. The work advances model-based RL with interpretable, object-centric representations and causality-guided decision-making, with potential applications in robotics and real-world planning under uncertainty.
Abstract
World models have been developed to support sample-efficient deep reinforcement learning agents. However, it remains challenging for world models to accurately replicate environments that are high-dimensional, non-stationary, and composed of multiple objects with rich interactions, since most world models learn holistic representations of all environmental components. By contrast, humans perceive the environment by decomposing it into discrete objects, which facilitates efficient decision-making. Motivated by this insight, we propose \emph{Slot Transformer Imagination with CAusality-aware reinforcement learning} (STICA), a unified framework in which object-centric Transformers serve as both the world model and the causality-aware policy and value networks. STICA represents each observation as a set of object-centric tokens, together with tokens for the agent action and the resulting reward, enabling the world model to predict token-level dynamics and interactions. The policy and value networks then estimate token-level cause--effect relations and use them in their attention layers, yielding causality-guided decision-making. Experiments on object-rich benchmarks demonstrate that STICA consistently outperforms state-of-the-art agents in both sample efficiency and final performance.
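To make the idea of using cause--effect estimates inside an attention layer concrete, the following is a minimal illustrative sketch and not the STICA implementation. It assumes that each token carries a scalar causal score in [0, 1] and that the score enters the attention logits as an additive log-bias; the abstract does not specify the mechanism, so both choices, as well as the names `causal_attention` and `causal_scores`, are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(query, keys, values, causal_scores, eps=1e-9):
    """Single-query scaled dot-product attention with a causal log-bias.

    ASSUMPTION (not from the paper): causal_scores[i] in [0, 1] is an
    estimate of how strongly token i influences the decision, and it is
    injected by adding log(score + eps) to the attention logit of token i.
    """
    d = len(query)
    logits = []
    for key, score in zip(keys, causal_scores):
        dot = sum(q * k for q, k in zip(query, key)) / math.sqrt(d)
        logits.append(dot + math.log(score + eps))
    weights = softmax(logits)
    dim_v = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim_v)]
    return output, weights


# Three object tokens with identical keys; the third is judged causally
# irrelevant (score 0.0), so it receives almost no attention weight.
query = [1.0, 0.0]
keys = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]
values = [[1.0], [2.0], [3.0]]
output, weights = causal_attention(query, keys, values, [1.0, 1.0, 0.0])
```

In this toy setting the first two tokens share the attention weight almost equally, while the token with score 0.0 is effectively masked out. A soft bias of this kind keeps the layer differentiable, which is one plausible reason to prefer it over a hard mask.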
