Inferring stochastic low-rank recurrent neural networks from neural data
Matthijs Pals, A Erdem Sağtekin, Felix Pei, Manuel Gloeckler, Jakob H Macke
TL;DR
The paper addresses inferring stochastic, low-rank recurrent neural networks (RNNs) from noisy neural data. With a rank-$R$ connectivity $\mathbf{J}=\mathbf{M}\mathbf{N}^{\top}$, the network dynamics can be formulated in a reduced $R$-dimensional latent space and then mapped back to the observed activity. The models are fit with variational sequential Monte Carlo (SMC), including an encoder-based proposal for nonlinear observations and a Generalised Teacher Forcing mechanism to stabilize temporal inference. Empirically, the approach yields lower-dimensional latent dynamics on EEG, hippocampal spiking, and monkey-reaching data; it compares favorably with state-of-the-art methods in reconstruction quality while using a smaller latent dimensionality; and, for piecewise-linear activations, it makes fixed-point analysis tractable. The work thus offers a principled generative framework that captures trial-to-trial neural variability with interpretable, analytically tractable dynamics, and it outlines concrete directions for extending it to broader noise models and multi-modal neural data.
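A minimal sketch of the two ingredients above, written under assumptions that are not taken from the paper: tanh units, an Euler-Maruyama step with ratio $\Delta t/\tau$ (`dt_tau`), isotropic Gaussian latent noise, and a Gaussian readout $\mathbf{y}_t \sim \mathcal{N}(\mathbf{M}\mathbf{z}_t, \sigma_y^2 \mathbf{I})$. Comparing `latent_step` with `full_step` shows why the latent formulation works: if the noise enters through $\mathbf{M}$, the $N$-unit state stays at $\mathbf{x}_t = \mathbf{M}\mathbf{z}_t$, so only $R$ numbers need to be tracked. `bootstrap_pf_loglik` is a bootstrap particle filter, the simplest SMC method, which proposes particles from the transition prior; the variational SMC described above instead learns the proposal (for example, from an encoder), which is not shown here. All function names are hypothetical.

```python
import numpy as np


def latent_step(z, M, Nmat, h, dt_tau, noise):
    """One Euler-Maruyama step of the ASSUMED latent dynamics
    z <- z + (dt/tau) * (-z + N^T tanh(M z + h)) + noise.
    Accepts a single state of shape (R,) or a batch of particles of shape (P, R)."""
    return z + dt_tau * (-z + np.tanh(z @ M.T + h) @ Nmat) + noise


def full_step(x, J, h, dt_tau, noise_units):
    """The same step in the full N-unit space, with J = M N^T."""
    return x + dt_tau * (-x + J @ np.tanh(x + h)) + noise_units


def simulate(M, Nmat, h, dt_tau, sigma_z, sigma_y, T, rng):
    """Sample latents z_t and observations y_t ~ N(M z_t, sigma_y^2 I) (assumed readout)."""
    n_units, R = M.shape
    z = rng.normal(scale=sigma_z, size=R)  # assumed initial distribution N(0, sigma_z^2 I)
    zs = [z]
    for _ in range(T - 1):
        z = latent_step(z, M, Nmat, h, dt_tau, rng.normal(scale=sigma_z, size=R))
        zs.append(z)
    zs = np.array(zs)
    y = zs @ M.T + rng.normal(scale=sigma_y, size=(T, n_units))
    return zs, y


def bootstrap_pf_loglik(y, M, Nmat, h, dt_tau, sigma_z, sigma_y, n_particles, rng):
    """Bootstrap particle filter: propose from the transition prior, weight by the
    Gaussian observation likelihood, resample at every step.
    Returns the log of an unbiased estimate of p(y_1:T)."""
    T, n_units = y.shape
    R = M.shape[1]
    z = rng.normal(scale=sigma_z, size=(n_particles, R))
    log_norm = -n_units * np.log(np.sqrt(2.0 * np.pi) * sigma_y)
    loglik = 0.0
    for t in range(T):
        if t > 0:
            noise = rng.normal(scale=sigma_z, size=z.shape)
            z = latent_step(z, M, Nmat, h, dt_tau, noise)
        resid = y[t] - z @ M.T  # (P, n_units): data minus each particle's prediction
        logw = log_norm - 0.5 * np.sum(resid**2, axis=1) / sigma_y**2
        c = logw.max()  # log-sum-exp shift for numerical stability
        w = np.exp(logw - c)
        loglik += c + np.log(w.mean())  # log of the average incremental weight
        idx = rng.choice(n_particles, size=n_particles, p=w / w.sum())  # multinomial resampling
        z = z[idx]
    return loglik
```

Typical use is to simulate data with `simulate`, then compare `bootstrap_pf_loglik` across candidate parameters. Because the likelihood estimate is unbiased, its logarithm lower-bounds $\log p(\mathbf{y}_{1:T})$ in expectation, which is the kind of objective a variational SMC scheme maximises over model and proposal parameters.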
Abstract
A central aim in computational neuroscience is to relate the activity of large populations of neurons to an underlying dynamical system. Models of these neural dynamics should ideally be both interpretable and fit the observed data well. Low-rank recurrent neural networks (RNNs) offer such interpretability because their dynamics are tractable. However, it is unclear how best to fit low-rank RNNs to data consisting of noisy observations of an underlying stochastic system. Here, we propose to fit stochastic low-rank RNNs with variational sequential Monte Carlo methods. We validate our method on several datasets consisting of both continuous and spiking neural data, where we obtain lower-dimensional latent dynamics than current state-of-the-art methods. Additionally, for low-rank models with piecewise-linear nonlinearities, we show how to efficiently identify all fixed points at a cost that is polynomial rather than exponential in the number of units, making analysis of the inferred dynamics tractable for large RNNs. Our method both elucidates the dynamical systems underlying experimental recordings and provides a generative model whose trajectories match observed variability.
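The polynomial-versus-exponential claim can be illustrated in the simplest case, rank $R=1$, under an assumed continuous-time form $\dot z = -z + \mathbf{n}^{\top}\mathrm{ReLU}(\mathbf{m} z + \mathbf{h})$, where $\mathbf{m}$ and $\mathbf{n}$ are the single columns of $\mathbf{M}$ and $\mathbf{N}$ and $\mathbf{h}$ is a hypothetical bias. Each unit switches on or off at a hyperplane $m_i z + h_i = 0$ in latent space, and $N$ hyperplanes cut $\mathbb{R}^R$ into at most $\sum_{k=0}^{R}\binom{N}{k} = O(N^R)$ regions, whereas a search in the full space would have to consider all $2^N$ activation patterns. For $R=1$ the hyperplanes are breakpoints on a line, so $N+1$ intervals suffice, and on each interval one linear equation gives the only candidate fixed point. The brute-force function is included only as an exponential-cost check. This sketch is an assumption-laden illustration; the general-rank algorithm referred to above may work differently, and all names are hypothetical.

```python
import itertools

import numpy as np


def latent_velocity(z, m, n, h):
    """dz/dt = -z + n^T relu(m z + h) for an ASSUMED rank-1 ReLU RNN."""
    return -z + n @ np.maximum(m * z + h, 0.0)


def _dedup_add(found, z, tol=1e-7):
    """Append z unless an (almost) identical fixed point is already stored."""
    if not any(abs(z - f) < tol for f in found):
        found.append(z)


def fixed_points_rank1(m, n, h, tol=1e-9):
    """Polynomial cost: breakpoints z = -h_i / m_i split the line into <= N + 1
    intervals; dz/dt is linear on each, so solve one linear equation per interval."""
    nonzero = np.abs(m) > tol
    breaks = np.sort(-h[nonzero] / m[nonzero])
    edges = np.concatenate(([-np.inf], breaks, [np.inf]))
    found = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Any point inside the interval reveals which units are active there.
        if np.isinf(lo) and np.isinf(hi):
            probe = 0.0
        elif np.isinf(lo):
            probe = hi - 1.0
        elif np.isinf(hi):
            probe = lo + 1.0
        else:
            probe = 0.5 * (lo + hi)
        active = (m * probe + h) > 0
        # On this interval: 0 = -z + sum_{i active} n_i (m_i z + h_i)
        slope = 1.0 - n[active] @ m[active]
        if abs(slope) < tol:
            continue  # degenerate piece (no or infinitely many solutions); skipped here
        z = (n[active] @ h[active]) / slope
        if lo - tol <= z <= hi + tol:  # keep only solutions inside their own interval
            _dedup_add(found, z)
    return sorted(float(z) for z in found)


def fixed_points_bruteforce(m, n, h, tol=1e-9):
    """Exponential cost reference: try all 2^N activation patterns, keep consistent ones."""
    found = []
    for pattern in itertools.product([False, True], repeat=len(m)):
        active = np.array(pattern)
        slope = 1.0 - n[active] @ m[active]
        if abs(slope) < tol:
            continue
        z = (n[active] @ h[active]) / slope
        pre = m * z + h
        if np.all(pre[active] >= -tol) and np.all(pre[~active] <= tol):
            _dedup_add(found, z)
    return sorted(float(z) for z in found)


# Hand-built example with three fixed points
m = np.array([1.0, -1.0])
n = np.array([2.0, -2.0])
h = np.array([-0.5, -0.5])
print(fixed_points_rank1(m, n, h))  # [-1.0, 0.0, 1.0]
```

In this example the fixed point at $z=0$ is stable (the local slope of $\dot z$ is $-1$), while those at $z=\pm 1$ are unstable (slope $+1$). The interval scan solves $N+1$ small linear problems, whereas the brute-force check grows as $2^N$, which is the gap the abstract refers to.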
