SPARTAN: A Sparse Transformer World Model Attending to What Matters

Anson Lei, Bernhard Schölkopf, Ingmar Posner

TL;DR

SPARTAN tackles the challenge of learning robust, adaptable world models by inducing local causal graphs among interacting objects. It introduces a sparsity-regularised, hard-attention Transformer that jointly discovers a context-dependent local graph and predicts future states, while tracking multi-hop dependencies across layers with a path matrix. The approach accommodates sparse interventions via learnable intervention tokens and supports test-time adaptation to unseen dynamics. Empirically, SPARTAN achieves competitive predictive accuracy, a lower structural Hamming distance (SHD) to the ground-truth graph than baselines, and improved robustness and few-shot adaptation across simulated and real-world-like domains such as Interventional Pong, CREATE, and Waymo traffic.
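To make the path matrix and the SHD metric concrete, the following is a minimal sketch rather than the authors' implementation. It assumes each layer exposes a binary adjacency matrix in which entry [i][j] = 1 means token i attends to token j, that residual connections act as self-edges, and that SHD is counted as the number of differing entries between two directed adjacency matrices. The function names `path_matrix` and `structural_hamming_distance` are hypothetical.

```python
def path_matrix(layer_adjacencies):
    """Compose per-layer binary adjacency matrices into a multi-hop path matrix.

    layer_adjacencies[l][i][j] == 1 means token i attends to token j in layer l.
    The returned P[i][j] == 1 means information from token j can reach token i
    after all layers. Self-edges are always kept (assumed residual connection).
    """
    n = len(layer_adjacencies[0])
    # Before any layer, each token depends only on itself.
    path = [[int(i == j) for j in range(n)] for i in range(n)]
    for adj in layer_adjacencies:
        # Add self-edges to this layer's graph.
        step = [[1 if (i == j or adj[i][j]) else 0 for j in range(n)]
                for i in range(n)]
        # Boolean matrix product: i depends on j if i attends to some k
        # whose representation already depends on j.
        path = [[int(any(step[i][k] and path[k][j] for k in range(n)))
                 for j in range(n)]
                for i in range(n)]
    return path


def structural_hamming_distance(predicted, target):
    """Count entries that differ between two directed adjacency matrices."""
    return sum(p != t
               for pred_row, target_row in zip(predicted, target)
               for p, t in zip(pred_row, target_row))


# Token 0 attends to token 1 in layer 1; token 2 attends to token 0 in layer 2.
layer_1 = [[0, 1, 0], [0, 0, 0], [0, 0, 0]]
layer_2 = [[0, 0, 0], [0, 0, 0], [1, 0, 0]]
paths = path_matrix([layer_1, layer_2])
```

In this example token 2 never attends to token 1 directly, yet the path matrix marks the dependency because it flows through token 0. A single-layer attention map would miss such two-hop edges, which is why a composed matrix is the natural object to compare against a ground-truth graph.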

Abstract

Capturing the interactions between entities in a structured way plays a central role in world models that flexibly adapt to changes in the environment. Recent works motivate the benefits of models that explicitly represent the structure of interactions and formulate the problem as discovering local causal structures. In this work, we demonstrate that reliably capturing these relationships in complex settings remains challenging. To remedy this shortcoming, we postulate that sparsity is a critical ingredient for the discovery of such local structures. To this end, we present the SPARse TrANsformer World model (SPARTAN), a Transformer-based world model that learns context-dependent interaction structures between entities in a scene. By applying sparsity regularisation on the attention patterns between object-factored tokens, SPARTAN learns sparse, context-dependent interaction graphs that accurately predict future object states. We further extend our model to adapt to sparse interventions with unknown targets in the dynamics of the environment. This results in a highly interpretable world model that can efficiently adapt to changes. Empirically, we evaluate SPARTAN against the current state-of-the-art in object-centric world models in observation-based environments and demonstrate that our model can learn local causal graphs that accurately reflect the underlying interactions between objects, achieving significantly improved few-shot adaptation to dynamics changes, as well as robustness against distractors.
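The idea of sparsity regularisation on attention patterns between object tokens can be illustrated with a small forward-pass sketch. This is not the paper's implementation. It assumes edge probabilities are obtained by a sigmoid over scaled query-key dot products, that a hard graph is obtained by thresholding those probabilities, and that the penalty is the expected number of edges. The name `hard_sparse_attention` and the `threshold` parameter are hypothetical, and thresholding is a non-differentiable stand-in for whatever relaxation is used during training.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def hard_sparse_attention(queries, keys, values, threshold=0.5):
    """One attention layer restricted to a thresholded (hard) edge set.

    Returns (outputs, adjacency, sparsity_penalty), where adjacency[i][j] == 1
    means token i is allowed to attend to token j, and sparsity_penalty is the
    sum of edge probabilities (the expected edge count under this assumption).
    """
    n, dim = len(queries), len(queries[0])
    scale = math.sqrt(dim)
    logits = [[sum(q * k for q, k in zip(queries[i], keys[j])) / scale
               for j in range(n)] for i in range(n)]
    edge_prob = [[sigmoid(x) for x in row] for row in logits]
    adjacency = [[int(p > threshold) for p in row] for row in edge_prob]

    outputs = []
    for i in range(n):
        parents = [j for j in range(n) if adjacency[i][j]]
        if not parents:
            # No selected parents: the token receives no message.
            outputs.append([0.0] * len(values[0]))
            continue
        # Softmax over the selected parents only (shifted for stability).
        top = max(logits[i][j] for j in parents)
        weights = {j: math.exp(logits[i][j] - top) for j in parents}
        total = sum(weights.values())
        outputs.append([sum(weights[j] * values[j][c] for j in parents) / total
                        for c in range(len(values[0]))])

    sparsity_penalty = sum(sum(row) for row in edge_prob)
    return outputs, adjacency, sparsity_penalty


queries = [[4.0, 0.0], [0.0, 4.0]]
keys = [[4.0, 0.0], [0.0, 4.0]]
values = [[1.0, 0.0], [0.0, 1.0]]
outputs, adjacency, penalty = hard_sparse_attention(queries, keys, values)
```

Adding `penalty`, scaled by a coefficient, to the prediction loss would push edge probabilities towards zero unless an edge is needed for accurate prediction. This is the general mechanism by which a sparsity term can favour small, context-dependent graphs.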

Paper Structure

This paper contains 26 sections, 8 equations, 6 figures, 8 tables.

Figures (6)

  • Figure 1: In the context of modelling physical interactions, a global causal graph is often uninformative and close to fully-connected. A time-dependent local causal graph better captures the sparse nature of interactions between entities. We present SPARTAN, a Transformer-based world model that discovers local causal structure using hard attention with sparsity regularisation.
  • Figure 2: Example rollouts in the two simulated environments with the learnt local causal graph visualised. Blue and red arrows between icons indicate the learnt causal edges and intervention targets respectively. In the Interventional Pong example, the intervention is that the ball slows down in the middle. SPARTAN identifies the same causal dependencies as the ground truth (e.g. ball causes the paddles to follow). The Transformer baseline learns edges that do not correspond to the ground truth. Similarly in CREATE, SPARTAN learns the correct causal edges (e.g. green ball bounces off the blue plank) while Transformer learns many spurious interventions.
  • Figure 3: Visualisation of the causal relationships learned by the models compared to human-labelled data. The ego vehicle is blue. A transparent rectangle indicates that the vehicle is not a causal parent of the ego vehicle. SPARTAN learns a similar attention pattern to human data (e.g. focus on vehicles that are moving in the same direction) whereas Transformer learns many spurious edges.
  • Figure 4: Adaptation errors on the two datasets. Each model has access to 5 trajectories with unknown environment index. "Unseen environments" refers to interventions that are not in the training set. Red lines indicate the prediction error when the environment index is provided, which serves as a lower bound on the adaptation error. SPARTAN (blue) consistently achieves the lowest errors across environments.
  • Figure 5: Example training curves.
  • ...and 1 more figure