Graph neural networks uncover structure and functions underlying the activity of simulated neural assemblies

Cédric Allier, Larissa Heinrich, Magdalena Schneider, Stephan Saalfeld

TL;DR

The paper addresses the challenge of interpreting complex, heterogeneous neural dynamics by learning a mechanistic, interpretable model rather than solely optimizing predictive accuracy. It introduces a message-passing graph neural network that maintains a low-dimensional latent embedding $\boldsymbol{a}_i$ for each neuron, jointly learns a connectivity matrix $\boldsymbol{W}$ and neuron-specific functions $\phi^*$ and $\psi^*$, and infers external inputs via $\Omega^*(t)$. The approach yields highly accurate rollouts and recovers key components of the underlying system: the wiring $\boldsymbol{W}$, neuron types from latent clusters, signaling functions, and external stimuli; symbolic regression recovers analytical forms of the learned functions. This framework demonstrates robust performance across large networks (up to $N=8000$) and varying architectures, offering a scalable, interpretable pathway to mechanistic modeling of neural activity with potential applicability to experimental data and transfer across configurations.
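The TL;DR describes a message-passing update built from a transfer function $\psi^*$, the connectivity $\boldsymbol{W}$, an update function $\phi^*$ that depends on each neuron's latent vector $\boldsymbol{a}_i$, and an external-input term $\Omega^*$. The exact equation is not given in this section. The sketch below therefore assumes the additive form $\dot{x}_i = \phi^*(\boldsymbol{a}_i, x_i) + \Omega_i(t) \sum_j W_{ij}\,\psi^*(x_j)$ and rolls it forward with explicit Euler steps. The names `predict_rate` and `rollout`, the toy `phi` and `psi`, and the Euler integrator are hypothetical stand-ins, not the authors' implementation.

```python
import numpy as np


def predict_rate(x, a, W, phi, psi, omega=None):
    """Predicted activity rate dx/dt for all N neurons at one time point.

    Assumed (hypothetical) form, not the paper's exact equation:
        dx_i/dt = phi(a_i, x_i) + omega_i * sum_j W_ij * psi(x_j)
    x: (N,) activities, a: (N, d) latent vectors, W: (N, N) connectivity.
    """
    messages = W @ psi(x)            # sum over j of W_ij * psi(x_j)
    if omega is not None:
        messages = omega * messages  # per-neuron external modulation omega_i(t)
    return phi(a, x) + messages


def rollout(x0, a, W, phi, psi, dt, n_steps, omega_fn=None):
    """Feed predictions back in with explicit Euler steps.

    Returns an array of shape (n_steps + 1, N) that includes x0.
    """
    traj = [np.asarray(x0, dtype=float)]
    for t in range(n_steps):
        omega = None if omega_fn is None else omega_fn(t)
        x = traj[-1]
        traj.append(x + dt * predict_rate(x, a, W, phi, psi, omega))
    return np.stack(traj)


def phi(a, x):
    # Hypothetical update function: decay at a rate read from a_i[0].
    return -a[:, 0] * x


psi = np.tanh  # hypothetical transfer function

rng = np.random.default_rng(0)
N, d = 5, 2
a = np.abs(rng.normal(size=(N, d)))
W = rng.normal(scale=0.1, size=(N, N))
traj = rollout(rng.normal(size=N), a, W, phi, psi, dt=0.01, n_steps=100)
print(traj.shape)  # (101, 5)
```

Because `rollout` feeds the model's own predictions back in, small per-step errors can compound. For this reason, long rollouts are a stricter test than one-step prediction.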

Abstract

Graph neural networks trained to predict observable dynamics can be used to decompose the temporal activity of complex heterogeneous systems into simple, interpretable representations. Here we apply this framework to simulated neural assemblies with thousands of neurons and demonstrate that it can jointly reveal the connectivity matrix, the neuron types, the signaling functions, and in some cases hidden external stimuli. In contrast to existing machine learning approaches such as recurrent neural networks and transformers, which emphasize predictive accuracy but offer limited interpretability, our method provides both reliable forecasts of neural activity and interpretable decomposition of the mechanisms governing large neural assemblies.
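The abstract states that the trained model reveals neuron types, and the TL;DR attributes this to clusters in the learned latent vectors $\boldsymbol{a}_i$. This section does not name a clustering algorithm. The sketch below uses plain k-means (Lloyd's algorithm with farthest-first initialization) purely as an illustration, run on synthetic 2-D stand-ins for $\boldsymbol{a}_i$ drawn from 4 well-separated types. The `kmeans` function and the synthetic data are hypothetical.

```python
import numpy as np


def kmeans(points, k, n_iter=100, seed=0):
    """Lloyd's k-means on (n, d) points; returns (labels, centroids)."""
    rng = np.random.default_rng(seed)
    # Farthest-first initialization: start from a random point, then repeatedly
    # add the point whose distance to the nearest chosen centroid is largest.
    centroids = [points[rng.integers(len(points))]]
    for _ in range(k - 1):
        d = np.min([((points - c) ** 2).sum(axis=1) for c in centroids], axis=0)
        centroids.append(points[d.argmax()])
    centroids = np.array(centroids)
    for _ in range(n_iter):
        # Assign every point to its nearest centroid (squared Euclidean distance).
        dists = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        # Move each centroid to the mean of its points (keep it if the cluster is empty).
        new = np.array([points[labels == c].mean(axis=0) if np.any(labels == c)
                        else centroids[c] for c in range(k)])
        if np.allclose(new, centroids):
            break
        centroids = new
    return labels, centroids


rng = np.random.default_rng(1)
centers = np.array([[0, 0], [5, 0], [0, 5], [5, 5]], dtype=float)
types = np.repeat(np.arange(4), 50)                          # ground-truth neuron types
a = centers[types] + rng.normal(scale=0.2, size=(200, 2))    # stand-in latent vectors a_i
labels, _ = kmeans(a, k=4)
```

Cluster indices are arbitrary, so recovered labels must be matched to true types before computing a classification accuracy.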

Paper Structure

This paper contains 8 sections, 7 equations, 18 figures, and 5 tables.

Figures (18)

  • Figure 1: The temporal activity of a simulated neural network (a) is converted into a densely connected graph (b), which is processed by a message-passing GNN (c). Each neuron (node $i$) receives activity signals $x_j$ from connected neurons (nodes $j$), which are processed by a transfer function $\psi^*$ and weighted by the matrix $\boldsymbol{W}$. The sum of these messages is combined with the functions $\phi^*$ and $\Omega^*$ to obtain the predicted activity rate $\widehat{\dot{\boldsymbol{x}}}_{i}$. In addition to the observed activity $x_i$, the GNN has access to learnable latent vectors $\boldsymbol{a}_i$ associated with each node $i$.
  • Figure 1: 1000 densely connected neurons with 4 neuron-dependent update functions. Results plotted over 20 epochs. (a) Learned latent vectors $\boldsymbol{a}_i$ of all neurons. (b) Learned update functions $\phi^*(\boldsymbol{a}, x)$. (c) Learned transfer function $\psi^*(x)$, normalized to a maximum value of 1. (d) Learned connectivity $\boldsymbol{W}_{ij}$. (e) Comparison of learned and true connectivity. Colors indicate true neuron types.
  • Figure 2: 1000 densely connected neurons with 4 neuron-dependent update functions. (a) Activity time series used for GNN training. This dataset contains $10^5$ time points. (b) Sample of 10 time series taken from (a). (c) True connectivity $\boldsymbol{W}_{ij}$. The inset shows $20\times20$ weights. (d) Learned connectivity. (e) Comparison of learned and true connectivity (given $g_i=10$ in the simulation equation). (f) Learned latent vectors $\boldsymbol{a}_i$ of all neurons. (g) Learned update functions $\phi^*(\boldsymbol{a}_i, x_i)$. The plot shows 1000 overlaid curves, one for each vector $\boldsymbol{a}_i$. (h) Learned transfer function $\psi^*(x_i)$, normalized to a maximum value of 1. Colors indicate true neuron types. True functions are overlaid in light gray.
  • Figure 2: Rollout inference performed with the GNN model trained on a simulation of 1000 densely connected neurons (Figure 2). Results plotted at time steps 400 and 800, respectively. (a) and (c) 25 learned activity traces plotted as a function of time. True activity traces are overlaid in light gray. (b) and (d) Comparison between true and learned activity values of 1000 neurons.
  • Figure 3: 2048 densely connected neurons with different neuron-dependent update and transfer functions (4 neuron types), in the presence of external inputs. The training dataset contains $10^5$ time points. (a) External inputs are represented by a time-dependent scalar field $\Omega_i(t)$ that scales the connectivity matrix $\boldsymbol{W}_{ij}$ (see the simulation equation with external inputs). The 1024 spatially ordered neurons on the left are modulated by this field; the other 1024 neurons on the right are not affected ($\Omega_i=1$). (b) Activity time series. (c) Sample of 10 time series used for training. (d) Comparison of learned and true connectivity $\boldsymbol{W}_{ij}$ (given $g_i=10$ in the simulation equation). (e) Comparison of learned and true $\Omega_i(t)$ values. (f) True field $\Omega_i(t)$ plotted at different time points. (g) Learned field $\Omega_i^*(t)$ plotted at different time points.
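Several captions compare learned against true quantities, such as the connectivity $\boldsymbol{W}_{ij}$ and the field $\Omega_i(t)$, in scatter plots. One common way to summarize such a plot is a least-squares line through all (true, learned) pairs; this metric is an assumption, not taken from the paper. $R^2$ measures how closely the learned values follow a linear relation with the true ones, while the slope isolates any global scale factor. The `compare_weights` helper and the synthetic data are hypothetical.

```python
import numpy as np


def compare_weights(w_true, w_learned):
    """Fit w_learned ~ slope * w_true + intercept over all entries.

    Returns (slope, intercept, r2), where r2 is the coefficient of determination.
    """
    x = np.ravel(w_true)
    y = np.ravel(w_learned)
    slope, intercept = np.polyfit(x, y, deg=1)   # ordinary least-squares line
    residuals = y - (slope * x + intercept)
    r2 = 1.0 - residuals.var() / y.var()         # 1 - SS_res / SS_tot
    return slope, intercept, r2


rng = np.random.default_rng(0)
w_true = rng.normal(size=(100, 100))
# Hypothetical learned weights: a globally rescaled copy plus small noise.
w_learned = 0.1 * w_true + rng.normal(scale=0.005, size=w_true.shape)
slope, intercept, r2 = compare_weights(w_true, w_learned)
```

Reporting the slope separately from $R^2$ keeps a uniform rescaling of the learned weights from being mistaken for a poor recovery of the wiring.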