Neural Relational Inference for Interacting Systems
Thomas Kipf, Ethan Fetaya, Kuan-Chieh Wang, Max Welling, Richard Zemel
TL;DR
Neural Relational Inference (NRI) tackles the problem of uncovering latent interaction graphs among multiple agents while learning their collective dynamics from unlabeled trajectories. It casts the task as a variational autoencoder whose discrete latent variable $\mathbf{z}$ assigns an edge type to each pair of agents. A graph neural network (GNN) encoder infers $\mathbf{z}$, and a GNN decoder with edge-type-specific message passing forecasts future states. To keep training differentiable, the model replaces the discrete edge variables with a continuous relaxation based on the Concrete distribution. To prevent degenerate decoders that ignore the latent graph, it predicts multiple steps ahead and uses a separate message-passing function for each edge type. A recurrent decoder variant handles non-Markovian dynamics. Empirically, NRI recovers interaction graphs that nearly match the ground truth in simulated physical systems and yields interpretable, accurate predictions on motion capture and NBA tracking data, demonstrating its applicability to real-world multi-agent systems. The work provides a principled framework for unsupervised relational reasoning with explicit latent graphs and introduces techniques such as dynamically re-evaluating the latent graph to handle interactions that change over time.
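The continuous relaxation via the Concrete distribution can be illustrated with a short NumPy sketch. It assumes the encoder outputs unnormalized log-probabilities (`logits`) of shape (edges, edge types). The function name `sample_concrete`, the temperature values, and the example logits are hypothetical and are not taken from the paper. The sketch adds Gumbel(0, 1) noise to the logits and applies a temperature-scaled softmax, so each edge receives a soft one-hot vector that approaches a hard categorical sample as the temperature goes to zero. NumPy does not track gradients, so a training implementation would express the same computation in an autodiff framework.

```python
import numpy as np


def sample_concrete(logits, tau=0.5, rng=None):
    """Draw one relaxed one-hot edge-type vector per edge (Concrete / Gumbel-softmax).

    logits: (num_edges, num_edge_types) unnormalized log-probabilities,
            standing in for hypothetical encoder outputs.
    tau:    temperature; smaller values give samples closer to one-hot.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Gumbel(0, 1) noise via inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1).
    # The small lower bound avoids log(0).
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    y = (logits + g) / tau
    # Numerically stable softmax over the edge-type axis.
    y = y - y.max(axis=-1, keepdims=True)
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)


# Example: 3 edges, 2 edge types (illustrative numbers only).
rng = np.random.default_rng(0)
logits = np.array([[2.0, -1.0], [0.0, 0.0], [-3.0, 3.0]])
z = sample_concrete(logits, tau=0.5, rng=rng)
print(z.shape)  # (3, 2)
```

Lower temperatures make samples closer to one-hot at the cost of higher-variance gradient estimates. This trade-off is a general property of the relaxation, not a setting reported in the paper.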
Abstract
Interacting systems are prevalent in nature, from dynamical systems in physics to complex societal dynamics. The interplay of components can give rise to complex behavior, which can often be explained using a simple model of the system's constituent parts. In this work, we introduce the neural relational inference (NRI) model: an unsupervised model that learns to infer interactions while simultaneously learning the dynamics purely from observational data. Our model takes the form of a variational auto-encoder, in which the latent code represents the underlying interaction graph and the reconstruction is based on graph neural networks. In experiments on simulated physical systems, we show that our NRI model can accurately recover ground-truth interactions in an unsupervised manner. We further demonstrate that we can find an interpretable structure and predict complex dynamics in real motion capture and sports tracking data.
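The idea that the reconstruction is computed by a graph neural network conditioned on the latent interaction graph can be illustrated with a single Markovian decoding step that uses edge-type-specific message passing. This is a minimal NumPy sketch under stated assumptions, not the authors' architecture. It assumes that each edge type k has its own message function (here a hypothetical one-layer `tanh` map with weights `edge_weights[k]`, standing in for a learned network). It also assumes that the per-type messages for a pair (i, j) are weighted by the edge-type vector z_ij and summed over senders i ≠ j at each receiver j, and that the predicted next state is the current state plus a learned update (the hypothetical `node_weight` map). The names `decoder_step`, `edge_weights`, and `node_weight` are illustrative.

```python
import numpy as np


def decoder_step(x, z, edge_weights, node_weight):
    """One Markovian decoding step with edge-type-specific message passing.

    x:            (n, d) current node states.
    z:            (n, n, K) edge-type weights (one-hot or relaxed); z[i, j] is for edge i -> j.
    edge_weights: list of K hypothetical (2*d, m) matrices, one message map per edge type.
    node_weight:  hypothetical (m, d) matrix mapping aggregated messages to a state update.
    Returns the predicted next states, shape (n, d).
    """
    n, _ = x.shape
    K = z.shape[-1]
    m = node_weight.shape[0]
    agg = np.zeros((n, m))
    for j in range(n):          # receiver
        for i in range(n):      # sender
            if i == j:
                continue
            pair = np.concatenate([x[i], x[j]])
            # Mix the K per-type messages using the edge-type weights z[i, j].
            msg = sum(z[i, j, k] * np.tanh(pair @ edge_weights[k]) for k in range(K))
            agg[j] += msg
    # Assumed residual update: next state = current state + f_v(aggregated messages).
    return x + np.tanh(agg) @ node_weight


# Example with random placeholder weights (illustrative only).
rng = np.random.default_rng(0)
n, d, m, K = 4, 2, 8, 2
x = rng.normal(size=(n, d))
z = np.zeros((n, n, K))
z[..., 0] = 1.0  # every ordered pair assigned edge type 0
W_e = [rng.normal(size=(2 * d, m)) for _ in range(K)]
W_v = rng.normal(size=(m, d))
print(decoder_step(x, z, W_e, W_v).shape)  # (4, 2)
```

If every entry of `z` is zero, no messages are sent and the step returns the input states unchanged. In this sketch, the prediction therefore depends on the interaction graph only through `z`.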
