Collective Relational Inference for learning heterogeneous interactions

Zhichao Han, Olga Fink, David S. Kammer

TL;DR

This work introduces Collective Relational Inference (CRI), a probabilistic framework for learning heterogeneous interactions by inferring the joint distribution of edge-types within local subgraphs, thereby capturing correlations among incoming interactions. Built on a generalized EM algorithm, CRI combines a flexible inference module with a physics-aware generative module to learn interaction laws, and it extends to evolving graphs via Evolving-CRI. Across causality discovery, heterogeneous particle interactions, and crystallization with changing topology, CRI demonstrates superior accuracy, data efficiency, and the ability to impose known constraints (e.g., Newtonian physics) to recover physics-consistent laws. The approach offers strong generalization to larger systems and provides a versatile tool for graph-structure learning and discovery of governing equations in complex multi-agent systems.
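The inference/generative loop summarized above can be illustrated on a deliberately tiny problem. The sketch below is hypothetical and is not the authors' implementation. It assumes a 1D system with two spring-like interaction types, where a per-type stiffness stands in for the learned interaction function. It also assumes a uniform prior over edge-type configurations and a known noise level `sigma`. The names `simulate`, `config_features`, `log_sum_exp`, and `cri_toy_em` are illustrative only.

Following the CRI idea of treating each node's incoming subgraph as one entity, the E-step computes a joint posterior over all K^m type assignments of a node's m incoming edges. The M-step then refits the stiffnesses by posterior-weighted least squares. Here the M-step is solved exactly, which is a special case of the generalized EM scheme.

```python
import itertools

import numpy as np


def simulate(n_nodes=4, n_samples=50, k_true=(1.0, 4.0), sigma=0.05, seed=0):
    """Toy data (hypothetical): 1D particles with two spring-like interaction types.

    Edge j -> i has a fixed type z[i, j] in {0, 1}; the observed acceleration of
    node i is sum_j k[z[i, j]] * (x_j - x_i) plus Gaussian noise.
    """
    rng = np.random.default_rng(seed)
    z = np.array([[(i + j) % 2 for j in range(n_nodes)] for i in range(n_nodes)])
    x = rng.uniform(-1.0, 1.0, size=(n_samples, n_nodes))
    disp = x[:, None, :] - x[:, :, None]  # disp[t, i, j] = x_j - x_i (zero on the diagonal)
    acc = np.einsum("tij,ij->ti", disp, np.asarray(k_true)[z])
    acc = acc + sigma * rng.standard_normal(acc.shape)
    return disp, acc, z


def config_features(disp, i, n_types):
    """Enumerate all joint type assignments of node i's incoming edges.

    phi[c, t, k] is the summed displacement of the edges that configuration c
    assigns to type k, so the predicted acceleration is phi[c, t] @ stiffness.
    """
    nbrs = [j for j in range(disp.shape[1]) if j != i]
    configs = list(itertools.product(range(n_types), repeat=len(nbrs)))
    phi = np.zeros((len(configs), disp.shape[0], n_types))
    for c, cfg in enumerate(configs):
        for m, j in enumerate(nbrs):
            phi[c, :, cfg[m]] += disp[:, i, j]
    return configs, phi


def log_sum_exp(v):
    """Numerically stable log(sum(exp(v)))."""
    m = v.max()
    return m + np.log(np.exp(v - m).sum())


def cri_toy_em(disp, acc, n_types=2, sigma=0.05, k_init=(0.5, 2.0), n_iter=20):
    feats = [config_features(disp, i, n_types) for i in range(acc.shape[1])]
    k = np.asarray(k_init, dtype=float)
    history, posteriors = [], []
    for _ in range(n_iter):
        # E-step: joint posterior over all edge-type configurations of each
        # node's incoming subgraph (uniform prior over configurations).
        posteriors, total_ll = [], 0.0
        for configs, phi in feats:
            resid = acc[None, :, len(posteriors)] - phi @ k  # shape (C, T)
            ll = -0.5 * (resid ** 2).sum(axis=1) / sigma ** 2
            norm = log_sum_exp(ll)
            total_ll += norm - np.log(len(configs))
            posteriors.append(np.exp(ll - norm))
        history.append(total_ll)  # observed-data log-likelihood (up to a constant)
        # M-step: posterior-weighted least squares for the per-type stiffnesses.
        A = np.zeros((n_types, n_types))
        b = np.zeros(n_types)
        for i, (_, phi) in enumerate(feats):
            wphi = posteriors[i][:, None, None] * phi
            A += np.einsum("ctk,ctl->kl", wphi, phi)
            b += np.einsum("ctk,t->k", wphi, acc[:, i])
        k = np.linalg.lstsq(A, b, rcond=None)[0]
    return k, posteriors, history


disp, acc, z_true = simulate()
k_hat, post, hist = cri_toy_em(disp, acc)
configs0, _ = config_features(disp, 0, 2)
print("estimated stiffnesses:", np.round(k_hat, 2))
print("MAP types of node 0's incoming edges:", configs0[int(np.argmax(post[0]))])
print("true types:", [int(z_true[0, j]) for j in (1, 2, 3)])
```

As in any mixture model, the recovered types are identifiable only up to a permutation of labels. Exhaustive enumeration costs K^m evaluations per node, so it is only practical at this toy size. How the paper scales the inference module to larger subgraphs is not reproduced here.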

Abstract

Interacting systems are ubiquitous in nature and engineering, ranging from particle dynamics in physics to functionally connected brain regions. These interacting systems can be modeled by graphs whose edges correspond to the interactions between interacting entities. Revealing interaction laws is of fundamental importance but also particularly challenging due to underlying configurational complexities. These challenges are exacerbated for heterogeneous systems, which are prevalent in reality: multiple interaction types coexist simultaneously, and relational inference is required. Here, we propose a novel probabilistic method for relational inference, which possesses two distinctive characteristics compared to existing methods. First, it infers the interaction types of different edges collectively by explicitly encoding the correlation among incoming interactions with a joint distribution. Second, it can handle systems whose topological structure varies over time. We evaluate the proposed methodology across several benchmark datasets and demonstrate that it outperforms existing methods in accurately inferring interaction types. We further show that, when combined with known constraints, it allows us, for example, to discover physics-consistent interaction laws of particle systems. Overall, the proposed model is data-efficient and generalizes to larger systems when trained on smaller ones. The developed methodology constitutes a key element for understanding interacting systems and may find application in graph structure learning.
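The second characteristic, handling a topology that varies over time, can be made concrete with a small sketch. It assumes a hypothetical graph construction not stated in the text above: edges connect particles closer than a fixed cutoff radius, recomputed at every time step. As particles move, edges appear and disappear. The size of each node's incoming subgraph therefore changes, and so does the number of joint edge-type configurations (K raised to the in-degree). The helpers `radius_graph` and `in_degrees` and the trajectory `traj` are illustrative only.

```python
import numpy as np


def radius_graph(pos, cutoff):
    """Directed edges (sender j, receiver i) between distinct particles closer than `cutoff`."""
    diff = pos[None, :, :] - pos[:, None, :]  # diff[i, j] = pos_j - pos_i
    dist = np.linalg.norm(diff, axis=-1)
    mask = (dist < cutoff) & ~np.eye(len(pos), dtype=bool)
    receivers, senders = np.nonzero(mask)
    return set(zip(senders.tolist(), receivers.tolist()))


def in_degrees(edges, n_nodes):
    """Number of incoming edges per node, i.e. the size of each node's subgraph."""
    deg = np.zeros(n_nodes, dtype=int)
    for _, receiver in edges:
        deg[receiver] += 1
    return deg


# Hypothetical two-step trajectory of three particles in 2D.
traj = [
    np.array([[0.0, 0.0], [0.5, 0.0], [2.0, 0.0]]),
    np.array([[0.0, 0.0], [1.5, 0.0], [2.0, 0.0]]),
]
n_types = 2  # K, the number of interaction types (assumed)
for t, pos in enumerate(traj):
    edges = radius_graph(pos, cutoff=1.0)
    deg = in_degrees(edges, len(pos))
    # time step, edge list, in-degree per node, joint configurations per node
    print(t, sorted(edges), deg.tolist(), (n_types ** deg).tolist())
```

At t = 0 only particles 0 and 1 are connected. At t = 1, particle 1 has moved next to particle 2. Both the edge set and the per-node configuration counts change, from (2, 2, 1) to (1, 2, 2).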

Paper Structure (38 sections, 29 equations, 13 figures, 14 tables, 1 algorithm)

Figures (13)

  • Figure 1: Comparison between (A) the neural relational inference (NRI) approach kipf2018neural and (B) our proposed method CRI for relational inference. NRI predicts the interaction types of different edges independently (e.g., the incoming edges of $v_1$). CRI treats the subgraph of each node (e.g., $S_{(1)}$) as one entity: we learn the joint distribution of the types of all edges in the subgraph, which allows modeling their collective influence on node states. The red bars depict a categorical distribution in which the bar length represents the probability of a particular realization. $\mathcal{F}_v$ and $\mathcal{F}_e$ denote the neural-network approximations of the node state update function and the interaction function, respectively. Other mathematical symbols are explained in the Methods section and summarized in the symbol table in the SI.
  • Figure 2: The pipeline of CRI. The probabilistic relational inference module takes as input the observed states of the interacting system at different time steps, together with the current estimate of the pairwise interactions. It infers the joint distribution of the interaction types of all edges in each subgraph. The generative module takes the predicted joint distribution for each subgraph, together with the observations, as input and updates the estimates of the different interaction functions. It can be any kind of graph neural network, e.g., the standard message-passing GNN used in kipf2018neural or the physics-induced graph neural network han2022learning. The red bars depict a categorical distribution in which the bar length represents the probability of a particular realization.
  • Figure 3: Illustration and results for causality discovery. (A) Outline of the causality discovery task. The state time series of different entities are observed, but whether a pair of entities has a causal relationship is unknown. The relational inference method is expected to infer the correct directed graph structure representing the underlying causal relationships. (B) The ground-truth causality graph and (C) prediction accuracy for different datasets. Mean and standard deviation are computed from five independent experiments. For NRI on VAR-a and VAR-c, we report the best performance among the five experiments because some random seeds lead to severely sub-optimal performance for NRI on these two datasets.
  • Figure 4: Test performance for the spring and charge experiments. Mean and standard deviation are computed from five independent experiments. (left column) Accuracy of the interaction type inference. (center column) Mean absolute error (MAE) of the pairwise force; NRI and MPM cannot infer pairwise forces. Empty symbols with dashed lines in E2 (Charge N5K2) indicate the range in which NRI-PIG'N'PI and MPM-PIG'N'PI do not learn any useful information about pairwise forces. (right column) MAE of the state (position and velocity combined) after 10 simulation steps.
  • Figure 5: Concept of Evolving-CRI to learn the heterogeneous interactions in crystallization problems. (left) System evolution during crystallization. Yellow and red colors indicate two different kinds of particles with heterogeneous interactions. (right) Schematic of Evolving-CRI consisting of an inference module and a generative module. Evolving-CRI is trained to predict the ground-truth acceleration. After training, the heterogeneous interactions are implicitly learnt.
  • ...and 8 more figures
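Figure 1 contrasts two kinds of models. NRI predicts the type of each edge independently, while CRI learns a joint distribution over the types of all incoming edges. A toy numerical example shows why this matters. The numbers are hypothetical, chosen only to exhibit a correlation among incoming edges: a node whose two incoming edges always have exactly one edge of type 1.

```python
import numpy as np

# Hypothetical joint distribution over the types (z1, z2) of two incoming
# edges of the same node, with K = 2 types: exactly one edge is of type 1.
joint = np.array([[0.0, 0.5],
                  [0.5, 0.0]])  # joint[a, b] = P(z1 = a, z2 = b)

# An independent (factorized) model with the same per-edge marginals.
marg1 = joint.sum(axis=1)  # P(z1)
marg2 = joint.sum(axis=0)  # P(z2)
factorized = np.outer(marg1, marg2)  # Q(z1 = a, z2 = b) = P(z1 = a) * P(z2 = b)

# Total variation distance between the joint and the factorized distribution.
tv = 0.5 * np.abs(joint - factorized).sum()

# Mutual information I(z1; z2) in nats: the correlation the factorized model discards.
nz = joint > 0
mi = float((joint[nz] * np.log(joint[nz] / factorized[nz])).sum())

print(factorized)  # every entry is 0.25, including (0, 0) and (1, 1), which the joint rules out
print(tv, mi)
```

The factorized model assigns probability 0.5 to configurations that never occur. The total variation distance is 0.5, and the mutual information is log 2 nats. No product distribution q1(a) q2(b) can reproduce this joint at all. Setting q1(0) q2(0) = 0 forces q1(0) = 0 or q2(0) = 0, which in turn zeroes one of the off-diagonal entries.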