SPARTAN: A Sparse Transformer World Model Attending to What Matters
Anson Lei, Bernhard Schölkopf, Ingmar Posner
TL;DR
SPARTAN tackles the challenge of learning robust, adaptable world models by inducing local causal graphs among interacting objects. It introduces a sparsity-regularised, hard-attention Transformer that jointly discovers a context-dependent local graph and predicts future states, while tracking multi-hop dependencies with a path matrix. The approach accommodates sparse interventions via learnable intervention tokens and supports test-time adaptation to unseen dynamics. Empirically, SPARTAN achieves competitive predictive accuracy, a lower structural Hamming distance (SHD) than baselines, and improved robustness and few-shot adaptation across simulated and real-world domains such as Interventional Pong, CREATE, and Waymo traffic.
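The TL;DR mentions a path matrix for tracking multi-hop dependencies and an evaluation by SHD. The following is a minimal sketch under assumptions that are ours, not the paper's. Each layer yields a binary hard-attention mask where `mask[i][j] == 1` means token i attends to token j. The residual connection is modelled as an always-on self-loop. The path matrix is the Boolean product of the per-layer masks. The SHD helper ignores self-loops and uses one common convention in which a missing, extra, or reversed edge each counts as a single error. All function names are hypothetical.

```python
from typing import List

Matrix = List[List[int]]  # square 0/1 matrix; entry [i][j] == 1 means "i depends on j"


def identity(n: int) -> Matrix:
    return [[1 if i == j else 0 for j in range(n)] for i in range(n)]


def with_self_loops(mask: Matrix) -> Matrix:
    """Set the diagonal, modelling the residual connection (an assumption of this sketch)."""
    n = len(mask)
    return [[1 if (i == j or mask[i][j]) else 0 for j in range(n)] for i in range(n)]


def bool_matmul(a: Matrix, b: Matrix) -> Matrix:
    """Boolean product: out[i][j] = 1 iff a[i][k] and b[k][j] for some k."""
    n = len(a)
    return [
        [1 if any(a[i][k] and b[k][j] for k in range(n)) else 0 for j in range(n)]
        for i in range(n)
    ]


def path_matrix(layer_masks: List[Matrix]) -> Matrix:
    """Compose per-layer hard-attention masks, listed from the first layer to the last.

    Entry [i][j] == 1 means output token i can depend on input token j
    through some chain of attention hops across the layers.
    """
    p = identity(len(layer_masks[0]))
    for mask in layer_masks:
        # A later layer reads the outputs of earlier ones, so it multiplies from the left.
        p = bool_matmul(with_self_loops(mask), p)
    return p


def shd(pred: Matrix, true: Matrix) -> int:
    """Structural Hamming distance between two directed graphs, ignoring self-loops.

    Each unordered pair {i, j} contributes 1 if its edge pattern differs.
    """
    n = len(true)
    errors = 0
    for i in range(n):
        for j in range(i + 1, n):
            if (pred[i][j], pred[j][i]) != (true[i][j], true[j][i]):
                errors += 1
    return errors
```

The order of hops matters in this sketch. Suppose token 1 attends to token 2 in the first layer and token 0 attends to token 1 in the second. Then entry `[0][2]` of the path matrix is 1, so token 0 picks up a two-hop dependency on token 2. If the two layers are swapped, that entry stays 0.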
Abstract
Capturing the interactions between entities in a structured way plays a central role in world models that flexibly adapt to changes in the environment. Recent works have highlighted the benefits of models that explicitly represent the structure of interactions and formulate the problem as discovering local causal structures. In this work, we demonstrate that reliably capturing these relationships in complex settings remains challenging. To remedy this shortcoming, we postulate that sparsity is a critical ingredient for the discovery of such local structures. To this end, we present the SPARse TrANsformer World model (SPARTAN), a Transformer-based world model that learns context-dependent interaction structures between entities in a scene. By applying sparsity regularisation on the attention patterns between object-factored tokens, SPARTAN learns sparse, context-dependent interaction graphs that accurately predict future object states. We further extend our model to adapt to sparse interventions with unknown targets in the dynamics of the environment. This results in a highly interpretable world model that can efficiently adapt to changes. Empirically, we evaluate SPARTAN against the current state-of-the-art in object-centric world models in observation-based environments and demonstrate that our model can learn local causal graphs that accurately reflect the underlying interactions between objects, achieving significantly improved few-shot adaptation to dynamics changes, as well as robustness against distractors.
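The abstract's core mechanism is sparsity regularisation on the attention patterns between object-factored tokens. The following is a forward-pass sketch only, and every design choice in it is an assumption of ours:

* There is one token per object, with queries, keys, and values already computed.
* A hard mask is obtained by thresholding scaled dot-product scores, and every token always keeps itself.
* Each token's output is a uniform average over the values it attends to.
* The penalty is the number of active cross-object edges per token.

The function names, the threshold, and the weight `lam` are hypothetical. Thresholding has zero gradient almost everywhere, so training such a mask needs a differentiable relaxation or estimator. This sketch does not reproduce whichever one SPARTAN uses.

```python
import math
from typing import List, Tuple

Vectors = List[List[float]]  # one row per object token
Mask = List[List[int]]       # mask[i][j] == 1 means token i attends to token j


def hard_attention(queries: Vectors, keys: Vectors, values: Vectors,
                   threshold: float = 0.0) -> Tuple[Vectors, Mask]:
    """Binary attention: token i attends to j iff their scaled score exceeds threshold."""
    n, d = len(queries), len(queries[0])
    outputs: Vectors = []
    mask: Mask = []
    for i in range(n):
        row = []
        for j in range(n):
            score = sum(q * k for q, k in zip(queries[i], keys[j])) / math.sqrt(d)
            # Each token always keeps itself (an assumption of this sketch).
            row.append(1 if (i == j or score > threshold) else 0)
        mask.append(row)
        selected = [values[j] for j in range(n) if row[j]]
        # Uniform average over the attended tokens' values (hypothetical aggregation).
        outputs.append([sum(col) / len(selected) for col in zip(*selected)])
    return outputs, mask


def sparsity_penalty(mask: Mask) -> float:
    """Number of active cross-object edges (diagonal excluded), averaged per token."""
    n = len(mask)
    return sum(mask[i][j] for i in range(n) for j in range(n) if i != j) / n


def total_loss(pred: Vectors, target: Vectors, mask: Mask, lam: float = 0.1) -> float:
    """Mean squared prediction error plus lam times the sparsity penalty."""
    sq_err = sum((p - t) ** 2 for pv, tv in zip(pred, target) for p, t in zip(pv, tv))
    count = sum(len(tv) for tv in target)
    return sq_err / count + lam * sparsity_penalty(mask)
```

Under a trainable relaxation, a larger `lam` trades prediction error for masks with fewer cross-object edges. That trade-off is the intuition behind reading the learned mask as a sparse, context-dependent interaction graph.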
