Slot Structured World Models

Jonathan Collu, Riccardo Majellaro, Aske Plaat, Thomas M. Moerland

TL;DR

Slot Structured World Models (SSWM) address the need for robust object-centric scene understanding by embedding an object-focused Slot Attention encoder within a graph-based latent dynamics model. The approach yields more disentangled object representations and improved multi-step prediction over the prior C-SWM baseline, especially in environments requiring complex object interactions. Empirical results on the Interactive Spriteworld benchmark show that SSWM achieves higher accuracy across 1, 5, and 10-step horizons and provides clearer latent-space structure and pixel-space predictions. These findings suggest object-centric, slot-based representations paired with iterative graph dynamics can enhance model-based reasoning and planning in environments with multiple interacting objects; code is available for reproduction.

Abstract

The ability to perceive and reason about individual objects and their interactions is a key goal in building intelligent artificial systems. State-of-the-art approaches use a feedforward encoder to extract object embeddings and a latent graph neural network to model the interactions between these object embeddings. However, the feedforward encoder cannot extract object-centric representations, nor can it disentangle multiple objects with similar appearance. To solve these issues, we introduce Slot Structured World Models (SSWM), a class of world models that combines an object-centric encoder (based on Slot Attention) with a latent graph-based dynamics model. We evaluate our method on the Spriteworld benchmark with simple rules of physical interaction, where Slot Structured World Models consistently outperform baselines on a range of (multi-step) prediction tasks with action-conditional object interactions. All code to reproduce the paper's experiments is available at https://github.com/JonathanCollu/Slot-Structured-World-Models.

Paper Structure

This paper contains 24 sections, 6 equations, 7 figures, 4 tables, 1 algorithm.

Figures (7)

  • Figure 1: Comparison of object masks learned by C-SWM (baseline, top row) on Shapes 2D and by SSWM (our method, bottom row) on Interactive Spriteworld. The left image of each row shows the input image, while the five images next to it show the learned object masks (encoder output) that enter the GNN. Top: The input image contains two duplicate objects (red circles and red triangles), which the C-SWM encoder cannot disentangle and instead conflates in a single embedding (for example, the two red circles end up in both the first and fourth embedding). These embeddings make modeling of pairwise object interactions in the downstream GNN less effective. Bottom: In contrast, our SSWM method can disentangle the duplicate items in the input image (orange circles) into distinct slots, due to the object-centric competitive attention mechanism in the Slot Attention encoder. This allows for more effective modeling of object interactions in the downstream GNN, as we show in the Results section.
  • Figure 2: Architectural design of Slot Structured World Models. Given an input image (left), we use a pretrained Slot Attention encoder to produce a set of object-centric embeddings ($z_t$) that capture the objects (and background) in the scene. These latent embeddings are then fed together with the chosen action into an iterative GNN transition module (defined in Algorithm 1) to predict the change in the next latent state ($\Delta z_t$). The graph-based transition model is trained to minimize the (object-wise) L2-norm between the latent prediction ($z_t + \Delta z_t$) and the Slot Attention embedding of the true next state ($z_{t+1}$).
  • Figure 3: Example states from Interactive Spriteworld. The top row shows five samples from different episodes. The bottom row shows five consecutive timesteps of a particular episode where the agent (white circle) keeps moving down, carrying all objects in the scene down as well (since they push against each other).
  • Figure 4: Pixel space decoding of the learned latent space predictions of SSWM. The top-left 3x4 images show predictions when only the agent/sprite is moving, the top-right 3x4 images show predictions when the agent carries one object, and the bottom 3x4 images show predictions when the agent moves multiple objects (that push against each other). For each 3x4 block, the three rows show 1, 5, and 10-step predictions, respectively. The four images in each of these rows show the true source state ('State'), the decoded state predicted by SSWM ('Pred. Next State'), the true state reached after taking the actions ('Next State'), and the prediction error between the latter two ('Next state - Pred'), where full black indicates no error.
  • Figure 5: Visualization of the object masks produced by the C-SWM object extractor, for the input images labeled as 'State' (left). We have six slots that each contain four feature maps, which we display above each other. Ideally, each slot uniquely identifies an object and represents its shape, which we need in the downstream prediction task. Instead, we see the model cannot detect exact shapes, nor can it isolate objects per slot (in contrast to the Slot Attention encoder of SSWM, as shown in Fig. 1). This suggests C-SWM learned a solution that does optimize its contrastive objective, but does not encode all relevant information.
  • ...and 2 more figures
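
The training objective described in the Figure 2 caption, and the multi-step prediction it supports, can be illustrated with a minimal NumPy sketch. This is not the authors' implementation: the function names `transition_loss`, `rollout`, and `toy_transition` are hypothetical, the toy transition merely stands in for the iterative GNN, and the sketch assumes that slot k of the prediction and slot k of the target embedding refer to the same object.

```python
import numpy as np


def transition_loss(z_t, delta_z, z_next):
    """Mean object-wise L2 norm between z_t + delta_z and z_next.

    All arrays have shape (num_slots, slot_dim). Assumption of this sketch:
    slots are already aligned between prediction and target.
    """
    pred = z_t + delta_z
    per_slot = np.linalg.norm(pred - z_next, axis=-1)  # one distance per slot
    return float(per_slot.mean())


def rollout(z_0, actions, transition_fn):
    """Multi-step latent prediction: repeatedly add the predicted change."""
    z = z_0
    for action in actions:
        z = z + transition_fn(z, action)
    return z


def toy_transition(z, action):
    """Hypothetical stand-in for the GNN: shifts every slot by the action."""
    return np.tile(action, (z.shape[0], 1))


z_0 = np.zeros((3, 2))                       # three slots, two latent dims
actions = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
z_2 = rollout(z_0, actions, toy_transition)  # two-step prediction
loss = transition_loss(z_0, toy_transition(z_0, actions[0]),
                       np.tile([1.0, 0.0], (3, 1)))
```

In a real model the change in latent state would come from a learned GNN conditioned on all slots and the action, and longer horizons (such as the 1, 5, and 10 steps reported in the paper) would be evaluated by extending the `actions` list passed to the rollout.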