Graph neural networks uncover structure and functions underlying the activity of simulated neural assemblies
Cédric Allier, Larissa Heinrich, Magdalena Schneider, Stephan Saalfeld
TL;DR
The paper addresses the challenge of interpreting complex, heterogeneous neural dynamics by learning a mechanistic, interpretable model rather than optimizing predictive accuracy alone. It introduces a message-passing graph neural network that maintains a low-dimensional latent embedding $\boldsymbol{a}_i$ for each neuron, jointly learns a connectivity matrix $\boldsymbol{W}$ and neuron-specific functions $\phi^*$ and $\psi^*$, and infers external inputs through a learned function $\Omega^*(t)$. The approach yields highly accurate rollouts and recovers the key components of the underlying system: the wiring $\boldsymbol{W}$, neuron types (from clusters in the latent space), signaling functions, and external stimuli. Symbolic regression applied to the learned functions then recovers their analytical forms. The framework performs robustly on large networks (up to $N=8000$ neurons) and across varying architectures, offering a scalable, interpretable pathway to mechanistic modeling of neural activity, with potential applicability to experimental data and transfer across configurations.
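To make the decomposition above concrete, here is a minimal forward-pass sketch assuming an additive per-neuron update $\frac{du_i}{dt} = \phi^*(\boldsymbol{a}_i, u_i) + \sum_j W_{ij}\,\psi^*(u_j) + \Omega^*(t)$. The additive form, the closed-form stand-ins for $\phi^*$, $\psi^*$ and $\Omega^*$, the explicit Euler integrator, and every name in the code (`phi_star`, `psi_star`, `omega_star`, `du_dt`, `rollout`) are illustrative assumptions, not the authors' implementation. In the paper these functions are learned rather than fixed.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def phi_star(a_i: Vector, u_i: float) -> float:
    """Hypothetical neuron-specific term phi*(a_i, u_i).

    Assumed form: a leak whose rate is set by the first latent coordinate,
    so neurons with different embeddings ("types") relax at different rates.
    """
    return -abs(a_i[0]) * u_i


def psi_star(u_j: float) -> float:
    """Hypothetical signaling function psi*(u_j) sent along each edge."""
    return math.tanh(u_j)


def omega_star(t: float) -> float:
    """Hypothetical external input Omega*(t), shared by all neurons."""
    return 0.1 * math.sin(t)


def du_dt(u: Vector, a: List[Vector], W: Matrix, t: float) -> Vector:
    """One message-passing step under the assumed additive update.

    Neuron i aggregates psi*(u_j) from every presynaptic neuron j with a
    nonzero weight W[i][j], then adds its own phi* term and Omega*(t).
    """
    n = len(u)
    out = []
    for i in range(n):
        message = sum(W[i][j] * psi_star(u[j]) for j in range(n) if W[i][j] != 0.0)
        out.append(phi_star(a[i], u[i]) + message + omega_star(t))
    return out


def rollout(u0: Vector, a: List[Vector], W: Matrix, dt: float, steps: int) -> List[Vector]:
    """Explicit Euler rollout; returns steps + 1 activity vectors."""
    u = list(u0)
    traj = [u]
    for k in range(steps):
        d = du_dt(u, a, W, k * dt)
        u = [u_i + dt * d_i for u_i, d_i in zip(u, d)]
        traj.append(u)
    return traj


# Usage: 4 neurons, 2-D latent embeddings, sparse random connectivity.
rng = random.Random(0)
N = 4
a = [[rng.uniform(0.5, 1.5), rng.uniform(-1.0, 1.0)] for _ in range(N)]
W = [[rng.gauss(0.0, 0.5) if (i != j and rng.random() < 0.5) else 0.0
      for j in range(N)] for i in range(N)]
traj = rollout([1.0] * N, a, W, dt=0.05, steps=100)
print(len(traj), [round(x, 3) for x in traj[-1]])
```

This sketch covers only the forward rollout. Training would presumably treat $\boldsymbol{W}$, the embeddings $\boldsymbol{a}_i$, and the networks standing in for $\phi^*$, $\psi^*$ and $\Omega^*$ as parameters fitted to observed activity.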
Abstract
Graph neural networks trained to predict observable dynamics can be used to decompose the temporal activity of complex heterogeneous systems into simple, interpretable representations. Here we apply this framework to simulated neural assemblies with thousands of neurons and demonstrate that it can jointly reveal the connectivity matrix, the neuron types, the signaling functions, and in some cases hidden external stimuli. In contrast to existing machine learning approaches such as recurrent neural networks and transformers, which emphasize predictive accuracy but offer limited interpretability, our method provides both reliable forecasts of neural activity and interpretable decomposition of the mechanisms governing large neural assemblies.
