Neural Relational Inference for Interacting Systems

Thomas Kipf, Ethan Fetaya, Kuan-Chieh Wang, Max Welling, Richard Zemel

TL;DR

Neural Relational Inference (NRI) tackles the problem of uncovering latent interaction graphs among multiple agents while learning their collective dynamics from unlabeled trajectories. It casts the task as a variational autoencoder with a discrete latent graph $\mathbf{z}$ representing edge types, using a graph neural network–based encoder to infer $\mathbf{z}$ and a per-edge-type GNN decoder to forecast future states. The model employs a continuous relaxation via the Concrete distribution for differentiable learning and uses strategies to avoid degenerate decoders, including multi-step prediction and edge-type–specific message passing, with a recurrent variant for non-Markovian dynamics. Empirically, NRI achieves near-ground-truth recovery of interaction graphs in physics simulations and delivers interpretable, accurate predictions on motion capture and NBA data, demonstrating the method’s applicability to real-world multi-agent systems. The work provides a principled framework for unsupervised relational reasoning with explicit latent graphs and introduces techniques such as dynamic graph re-evaluation to handle changing interactions over time.
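The continuous relaxation mentioned above can be illustrated with a small sampler. This is a minimal sketch, not the authors' implementation: it assumes the standard Gumbel-softmax form of the Concrete distribution, and the function name `sample_concrete`, the default temperature `tau=0.5`, and the example logits are hypothetical choices made for illustration.

```python
import math
import random


def sample_concrete(logits, tau=0.5, rng=random):
    """Draw one relaxed one-hot sample over edge types for a single edge.

    logits: unnormalized log-probabilities, one per edge type.
    tau:    temperature (assumed default); lower values push the sample
            closer to a discrete one-hot vector.
    rng:    any object with a random() method returning a float in [0, 1).
    """
    scores = []
    for logit in logits:
        # Gumbel(0, 1) noise: g = -log(-log(u)), with u ~ Uniform(0, 1).
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        scores.append((logit + gumbel) / tau)

    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Example: three hypothetical edge types for one edge.
sample = sample_concrete([0.0, 1.0, -1.0], tau=0.5, rng=random.Random(0))
```

Because the sample is a smooth function of the logits, gradients can flow from the decoder's prediction loss back into the encoder, which is what makes the discrete latent graph trainable end to end.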

Abstract

Interacting systems are prevalent in nature, from dynamical systems in physics to complex societal dynamics. The interplay of components can give rise to complex behavior, which can often be explained using a simple model of the system's constituent parts. In this work, we introduce the neural relational inference (NRI) model: an unsupervised model that learns to infer interactions while simultaneously learning the dynamics purely from observational data. Our model takes the form of a variational auto-encoder, in which the latent code represents the underlying interaction graph and the reconstruction is based on graph neural networks. In experiments on simulated physical systems, we show that our NRI model can accurately recover ground-truth interactions in an unsupervised manner. We further demonstrate that we can find an interpretable structure and predict complex dynamics in real motion capture and sports tracking data.

Paper Structure

This paper contains 39 sections, 13 equations, 17 figures, 2 tables.

Figures (17)

  • Figure 1: Physical simulation of 2D particles coupled by invisible springs (left) according to a latent interaction graph (right). In this example, solid lines between two particle nodes denote connections via springs whereas dashed lines denote the absence of a coupling. In general, multiple, directed edge types -- each with a different associated relation -- are possible.
  • Figure 2: Node-to-edge ($v{\rightarrow}e$) and edge-to-node ($e{\rightarrow}v$) operations for moving between node and edge representations in a GNN. $v{\rightarrow}e$ represents concatenation of node embeddings connected by an edge, whereas $e{\rightarrow}v$ denotes the aggregation of edge embeddings from all incoming edges. In our notation in Eqs. \ref{eq:gnn_simple1}--\ref{eq:gnn_simple2}, every such operation is followed by a small neural network (e.g. a 2-layer MLP), here denoted by a black arrow. For clarity, we highlight which node embeddings are combined to form a specific edge embedding ($v{\rightarrow}e$) and which edge embeddings are aggregated to a specific node embedding ($e{\rightarrow}v$).
  • Figure 3: The NRI model consists of two jointly trained parts: an encoder that predicts a probability distribution $q_{\phi}(\mathbf{z}|\mathbf{x})$ over the latent interactions given input trajectories, and a decoder that generates trajectory predictions conditioned on both the latent code of the encoder and the previous time step of the trajectory. The encoder takes the form of a GNN with multiple rounds of node-to-edge ($v{\rightarrow}e$) and edge-to-node ($e{\rightarrow}v$) message passing, whereas the decoder runs multiple GNNs in parallel, one for each edge type supplied by the latent code of the encoder $q_{\phi}(\mathbf{z}|\mathbf{x})$.
  • Figure 4: Examples of trajectories used in our experiments from simulations of particles connected by springs (left), charged particles (middle), and phase-coupled oscillators (right).
  • Figure 5: Trajectory predictions from a trained NRI model (unsupervised). Semi-transparent paths denote the first 49 time steps of ground-truth input to the model, from which the interaction graph is estimated. Solid paths denote self-conditioned model predictions.
  • ...and 12 more figures
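
The node-to-edge and edge-to-node operations described in the Figure 2 caption can be sketched in a few lines. This is an illustrative sketch under stated assumptions, not the paper's code: the function names `node_to_edge` and `edge_to_node` are hypothetical, aggregation is assumed to be a plain sum, each edge is written as `(receiver, sender)`, and the small neural network that follows each operation in the caption is omitted.

```python
def node_to_edge(h_nodes, edges):
    """v->e: build one edge embedding per directed edge by concatenation.

    h_nodes: list of node embeddings (each a list of floats).
    edges:   list of (receiver, sender) index pairs (assumed convention).
    """
    return {(i, j): h_nodes[i] + h_nodes[j] for (i, j) in edges}


def edge_to_node(h_edges, num_nodes, dim):
    """e->v: sum the embeddings of all incoming edges at each node.

    h_edges: dict mapping (receiver, sender) to an edge embedding of length dim.
    Summation is an assumption; the caption only says "aggregation".
    """
    out = [[0.0] * dim for _ in range(num_nodes)]
    for (receiver, _sender), h in h_edges.items():
        for k in range(dim):
            out[receiver][k] += h[k]
    return out


# Example: 3 nodes with 1-dimensional embeddings, fully connected, no self-loops.
h_nodes = [[1.0], [2.0], [3.0]]
edges = [(i, j) for i in range(3) for j in range(3) if i != j]

h_edges = node_to_edge(h_nodes, edges)      # e.g. h_edges[(0, 1)] == [1.0, 2.0]
h_agg = edge_to_node(h_edges, num_nodes=3, dim=2)
# h_agg == [[2.0, 5.0], [4.0, 4.0], [6.0, 3.0]]
```

In a full model, a learned function such as a 2-layer MLP would be applied to each edge embedding after `node_to_edge` and to each node embedding after `edge_to_node`, as the caption indicates.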