Deep Reinforcement Learning via Object-Centric Attention
Jannis Blüml, Cedric Derstroff, Bjarne Gregori, Elisabeth Dillies, Quentin Delfosse, Kristian Kersting
TL;DR
The paper tackles the limited generalization of deep RL agents trained on raw pixels by introducing Object-Centric Attention via Masking (OCCAM), which masks out background pixels so that only task-relevant objects remain. It uses simple object bounding boxes to build four abstraction levels (Object, Binary, Class, Planes) and evaluates them on Atari games, including perturbed variants that test robustness. Empirical results show that OCCAM can match or exceed the performance of pixel-based PPO and substantially improves resilience to visual perturbations, without requiring domain-specific object pipelines or symbolic reasoning. The findings suggest that structured, object-centric representations can enhance generalization and sample efficiency in RL, offering a practical alternative to fully symbolic or heavily preprocessed approaches.
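The TL;DR describes OCCAM as hiding everything outside object bounding boxes, at four abstraction levels. The NumPy sketch below shows one plausible reading of those levels. It makes several illustrative assumptions that are not the authors' implementation: the box format `(x, y, w, h, class_id)`, the function names, and the meaning assigned to each level. Here, Object keeps the original pixels, Binary produces a 0/1 mask, Class fills each box with a per-class value, and Planes uses one binary channel per class.

```python
import numpy as np

# ASSUMPTION: boxes are (x, y, w, h, class_id) in pixel coordinates,
# (x, y) is the top-left corner, and class_id is in [0, num_classes).
Box = tuple[int, int, int, int, int]


def _clip(box: Box, height: int, width: int) -> tuple[int, int, int, int]:
    """Clip a box to the frame and return (x0, y0, x1, y1) slice bounds."""
    x, y, w, h, _ = box
    return max(x, 0), max(y, 0), min(x + w, width), min(y + h, height)


def object_mask(frame: np.ndarray, boxes: list[Box]) -> np.ndarray:
    """Assumed 'Object' level: original pixels inside boxes, zeros elsewhere."""
    out = np.zeros_like(frame)
    H, W = frame.shape[:2]
    for box in boxes:
        x0, y0, x1, y1 = _clip(box, H, W)
        out[y0:y1, x0:x1] = frame[y0:y1, x0:x1]
    return out


def binary_mask(frame_shape: tuple, boxes: list[Box]) -> np.ndarray:
    """Assumed 'Binary' level: 1 inside any box, 0 for background."""
    H, W = frame_shape[:2]
    out = np.zeros((H, W), dtype=np.uint8)
    for box in boxes:
        x0, y0, x1, y1 = _clip(box, H, W)
        out[y0:y1, x0:x1] = 1
    return out


def class_mask(frame_shape: tuple, boxes: list[Box]) -> np.ndarray:
    """Assumed 'Class' level: each box filled with class_id + 1 (background 0).

    Later boxes overwrite earlier ones where they overlap.
    Requires class_id + 1 to fit in uint8 (at most 255 classes).
    """
    H, W = frame_shape[:2]
    out = np.zeros((H, W), dtype=np.uint8)
    for box in boxes:
        x0, y0, x1, y1 = _clip(box, H, W)
        out[y0:y1, x0:x1] = box[4] + 1
    return out


def planes_mask(frame_shape: tuple, boxes: list[Box], num_classes: int) -> np.ndarray:
    """Assumed 'Planes' level: one binary channel per object class."""
    H, W = frame_shape[:2]
    out = np.zeros((num_classes, H, W), dtype=np.uint8)
    for box in boxes:
        x0, y0, x1, y1 = _clip(box, H, W)
        out[box[4], y0:y1, x0:x1] = 1
    return out


if __name__ == "__main__":
    # Toy 6x8 RGB "frame" with two hypothetical objects of classes 0 and 1.
    frame = np.full((6, 8, 3), 7, dtype=np.uint8)
    boxes = [(1, 1, 2, 2, 0), (5, 3, 2, 2, 1)]
    print(object_mask(frame, boxes).shape)
    print(binary_mask(frame.shape, boxes))
    print(class_mask(frame.shape, boxes))
    print(planes_mask(frame.shape, boxes, num_classes=2).shape)
```

Where boxes overlap, later boxes overwrite earlier ones in the Object and Class variants. The Planes variant avoids this ambiguity by giving each class its own channel, at the cost of a larger input tensor.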
Abstract
Deep reinforcement learning agents trained on raw pixel inputs often fail to generalize beyond their training environments, as they rely on spurious correlations and irrelevant background details. Object-centric agents have recently emerged to address this issue. However, they require representations tailored to each task's specifications: unlike pixel-based deep agents, no single object-centric architecture can be applied to every environment. Inspired by principles of cognitive science and Occam's Razor, we introduce Object-Centric Attention via Masking (OCCAM), which selectively preserves task-relevant entities while filtering out irrelevant visual information. In doing so, OCCAM exploits the object-centric inductive bias. Empirical evaluations on Atari benchmarks demonstrate that OCCAM significantly improves robustness to novel perturbations and reduces sample complexity, while achieving performance similar to or better than conventional pixel-based RL. These results suggest that structured abstraction can enhance generalization without requiring explicit symbolic representations or domain-specific object extraction pipelines.
