Attention for Causal Relationship Discovery from Biological Neural Dynamics

Ziyu Lu, Anika Tabassum, Shruti Kulkarni, Lu Mi, J. Nathan Kutz, Eric Shea-Brown, Seung-Hwan Lim

TL;DR

The paper addresses learning Granger causality in nonlinear, dynamic neural networks by introducing Causalformer, a transformer-based model whose decoder cross attention encodes directed interactions among neurons during multivariate time-series forecasting. It demonstrates, on synthetic data generated with the Izhikevich model, that the cross-attention-derived connectivity can match or surpass traditional Multivariate Granger Causality (MVGC) methods, particularly as network size grows. The approach provides a scalable, nonlinear alternative for causal representation learning in neuroscience and offers a framework for interpreting inter-neuronal influences from large-scale neural population recordings. While promising, the work notes limitations related to real-world data non-stationarity, partial observability, and identifiability, outlining avenues for future enhancement such as handling spike-train data and broader methodological benchmarks.

Abstract

This paper explores the potential of the transformer models for learning Granger causality in networks with complex nonlinear dynamics at every node, as in neurobiological and biophysical networks. Our study primarily focuses on a proof-of-concept investigation based on simulated neural dynamics, for which the ground-truth causality is known through the underlying connectivity matrix. For transformer models trained to forecast neuronal population dynamics, we show that the cross attention module effectively captures the causal relationship among neurons, with an accuracy equal or superior to that for the most popular Granger causality analysis method. While we acknowledge that real-world neurobiology data will bring further challenges, including dynamic connectivity and unobserved variability, this research offers an encouraging preliminary glimpse into the utility of the transformer model for causal representation learning in neuroscience.

Paper Structure

This paper contains 14 sections, 5 equations, 4 figures, 2 tables.

Figures (4)

  • Figure 1: Experiment overview. (a) Example five-neuron network connectivity structure. Red arrows represent excitatory connections, and gray arrows represent inhibitory connections. (b) Example membrane potential of the neuronal population in (a), simulated with the Izhikevich model. Firing threshold is marked by dashed black horizontal lines. (c) Illustration of the encoder-decoder architecture of Causalformer. (d) Example decoder cross attention matrix from Causalformer trained on the membrane potential of (a). (e) Ground-truth connectivity matrix for (a). Entry $(i, j)$ is 1 if there is a connection from neuron $j$ to $i$ (regardless of inhibitory or excitatory), and 0 otherwise.
  • Figure 2: Results with random networks of 4 different sizes $N=5, 10, 20, 40$ and 4 different connectivity probabilities $p=0.2, 0.4, 0.6, 0.8$. Ten different topologies were simulated for each $N$ and $p$. Violin plots show the empirical distributions estimated from the 10 random networks. White dots mark the medians of the distributions, and black dots mark the 25th to the 75th percentile. Y-axis shows AUROC (between 0 and 1, higher is better). X-axis shows the methods used for inferring causality. MVGC (AIC), MVGC (BIC), and Causalformer stand for Multivariate Granger Causality with VAR model order selected by AIC, the same with model order selected by BIC, and Causalformer decoder cross attention, respectively.
  • Figure 3: Causalformer architecture. In encoder local self attention, each neuron can only attend to the embedding of its own history. Similarly, in decoder local cross attention, the target embedding of each neuron can only attend to the encoder representation of its own history. In decoder global cross attention, the target embedding of each neuron can attend to the encoded history representations of all neurons, including itself. The distinction between local and global attention was first introduced in Grigsby et al. (2021) as an architectural bias. Causal relationships between neurons are inferred from the decoder's global cross attention module.
  • Figure 4: Causalformer prediction quality. (a) Example true and predicted normalized membrane potential traces. Since we make a one-step prediction in each test sample, predicted traces are made by concatenating predictions across test samples. Test $R^2$ scores are computed between the true and predicted traces on the test set. (b) Distribution of test $R^2$ scores (averaged across all neurons) achieved by all Causalformer models trained on the random networks. Note that test $R^2$ scores are concentrated around 0.935, similar to those in (a), which means that the vast majority of the models capture the dynamics reasonably well, judging by the visualization in (a).
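
The captions for Figures 1 and 2 describe scoring a cross attention matrix against the binary ground-truth connectivity matrix with AUROC. The following is a minimal sketch of that comparison, not the authors' implementation. The function names `auroc` and `connectivity_auroc`, the `include_self` flag, and the choice to drop diagonal (self-connection) entries by default are assumptions made for illustration, as is the toy three-neuron data.

```python
import numpy as np


def auroc(scores, labels):
    """AUROC as the probability that a random true edge outscores
    a random non-edge, with ties counted as 0.5."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    pos, neg = scores[labels], scores[~labels]
    if pos.size == 0 or neg.size == 0:
        raise ValueError("AUROC needs at least one edge and one non-edge")
    # Compare every (edge, non-edge) pair of scores.
    diff = pos[:, None] - neg[None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


def connectivity_auroc(attention, ground_truth, include_self=False):
    """Score an N x N attention matrix against an N x N binary
    connectivity matrix, where entry (i, j) refers to j -> i.

    Dropping the diagonal by default is an assumption of this sketch.
    """
    attention = np.asarray(attention, dtype=float)
    ground_truth = np.asarray(ground_truth)
    mask = np.ones_like(ground_truth, dtype=bool)
    if not include_self:
        np.fill_diagonal(mask, False)
    return auroc(attention[mask], ground_truth[mask] != 0)


# Toy three-neuron example (made-up numbers, for illustration only).
truth = np.array([[0, 1, 0],
                  [0, 0, 1],
                  [1, 0, 0]])
attn = np.array([[0.5, 0.4, 0.1],
                 [0.2, 0.5, 0.3],
                 [0.6, 0.1, 0.3]])
score = connectivity_auroc(attn, truth)
print(f"AUROC: {score:.2f}")
```

In the toy example every true off-diagonal edge receives more attention than every non-edge, so the score is at its maximum. AUROC is threshold-free, which makes it convenient for comparing continuous attention weights with the statistics produced by a Granger causality test.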