
Generative Modeling of Neural Dynamics via Latent Stochastic Differential Equations

Ahmed ElGazzar, Marcel van Gerven

TL;DR

The paper addresses the challenge of modeling neural population dynamics with interpretable mechanistic priors and flexible data-driven components. It introduces a probabilistic framework of latent continuous-time stochastic differential equations (SDEs) and demonstrates a biophysically informed instantiation using coupled oscillators (CO-SDE), trained via variational inference with an augmented posterior and Girsanov-based KL terms. Across three neuroscience datasets, the hybrid oscillator-neural models achieve competitive predictive performance with far fewer parameters than deep baselines and provide principled uncertainty estimates and interpretable dynamical descriptions. This approach offers a scalable, interpretable, and uncertainty-aware toolkit for understanding neural computations and could inform online decoding and theoretical neuroscience.
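The Girsanov-based KL term mentioned above compares a posterior SDE and a prior SDE that share the same diffusion. When the two differ only in their drifts, the KL divergence between their path distributions equals the expected time integral of 1/2 ((f_post - f_prior) / g)^2, with the expectation taken along posterior paths. The sketch below is a minimal one-dimensional illustration that assumes a constant, shared diffusion g and an Euler-Maruyama discretization. The function name `girsanov_path_kl` and the drifts in the example are hypothetical and are not the authors' implementation.

```python
import math
import random
from typing import Callable

Drift = Callable[[float, float], float]  # (t, x) -> drift value


def girsanov_path_kl(
    f_post: Drift,
    f_prior: Drift,
    g: float,
    x0: float,
    t_end: float,
    n_steps: int,
    rng: random.Random,
) -> float:
    """Single-sample estimate of KL(posterior paths || prior paths) for a 1-D SDE.

    Assumes both SDEs share the constant diffusion g (hypothetical setup), so
    by Girsanov's theorem the KL is E[ integral of 0.5 * ((f_post - f_prior) / g)^2 dt ]
    along paths of the posterior SDE.
    """
    dt = t_end / n_steps
    x, t, kl = x0, 0.0, 0.0
    for _ in range(n_steps):
        drift_post = f_post(t, x)
        u = (drift_post - f_prior(t, x)) / g  # drift mismatch in units of noise
        kl += 0.5 * u * u * dt
        # Euler-Maruyama step of the *posterior* SDE (paths are drawn from q)
        x += drift_post * dt + g * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        t += dt
    return kl


# Example: Ornstein-Uhlenbeck prior vs. a posterior with a constant extra push of 0.5.
rng = random.Random(0)
kl = girsanov_path_kl(lambda t, x: -x + 0.5, lambda t, x: -x, g=1.0, x0=0.0,
                      t_end=2.0, n_steps=200, rng=rng)
# Constant mismatch 0.5 with g = 1 over 2 time units: 0.5 * 0.25 * 2.0 = 0.25 (up to rounding).
```

Because the drifts in this example differ by a constant, the estimate does not depend on the noise draws. With a state-dependent mismatch, one would average over several posterior samples. For multi-dimensional latents with diagonal diffusion, the per-dimension terms are summed.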

Abstract

We propose a probabilistic framework for developing computational models of biological neural systems. In this framework, physiological recordings are viewed as discrete-time partial observations of an underlying continuous-time stochastic dynamical system which implements computations through its state evolution. To model this dynamical system, we employ a system of coupled stochastic differential equations with differentiable drift and diffusion functions and use variational inference to infer its states and parameters. This formulation enables seamless integration of existing mathematical models in the literature, neural networks, or a hybrid of both to learn and compare different models. We demonstrate this in our framework by developing a generative model that combines coupled oscillators with neural networks to capture latent population dynamics from single-cell recordings. Evaluation across three neuroscience datasets spanning different species, brain regions, and behavioral tasks shows that these hybrid models achieve competitive performance in predicting stimulus-evoked neural and behavioral responses compared to sophisticated black-box approaches while requiring an order of magnitude fewer parameters, providing uncertainty estimates, and offering a natural language for interpretation.
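To make the generative picture concrete, the sketch below simulates a small latent SDE of coupled phase oscillators with Euler-Maruyama and reads it out as Poisson spike counts at coarser, discrete observation times. It assumes a Kuramoto-style phase drift, a cosine-exponential readout, and a scalar input drive. These choices, and every function and parameter name in it (`simulate_coupled_oscillators`, `observe_spikes`, `sample_poisson`), are hypothetical illustrations. They are not the CO-SDE formulation used in the paper, which also includes neural-network components.

```python
import math
import random
from typing import Callable


def simulate_coupled_oscillators(
    omega: list[float],               # natural frequencies (rad/s), one per oscillator
    coupling: list[list[float]],      # coupling[i][j]: pull of oscillator j on oscillator i
    sigma: float,                     # shared diffusion (noise) scale
    drive: Callable[[float], float],  # hypothetical scalar input u(t) added to every phase
    theta0: list[float],              # initial phases
    t_end: float,
    dt: float,
    rng: random.Random,
) -> list[list[float]]:
    """Euler-Maruyama simulation of an assumed Kuramoto-style phase SDE:
    d theta_i = [omega_i + u(t) + sum_j K_ij sin(theta_j - theta_i)] dt + sigma dW_i.
    Returns the phase vector at every solver step, including t = 0.
    """
    n_steps = round(t_end / dt)
    theta = list(theta0)
    path = [list(theta)]
    for step in range(n_steps):
        t = step * dt
        drift = [
            w + drive(t) + sum(k * math.sin(tj - ti) for k, tj in zip(row, theta))
            for w, row, ti in zip(omega, coupling, theta)
        ]
        theta = [
            ti + f * dt + sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0)
            for ti, f in zip(theta, drift)
        ]
        path.append(list(theta))
    return path


def sample_poisson(lam: float, rng: random.Random) -> int:
    """Knuth's multiplication method; adequate for the small per-bin rates used here."""
    threshold, k, p = math.exp(-lam), 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def observe_spikes(
    path: list[list[float]],
    weights: list[list[float]],  # weights[n][i]: hypothetical readout of neuron n from oscillator i
    bias: list[float],           # per-neuron log-rate offset
    obs_every: int,              # solver steps per observation bin (discrete, partial observation)
    dt: float,
    rng: random.Random,
) -> list[list[int]]:
    """Sample spike counts from log-rate = bias + sum_i w_i cos(theta_i), once per bin."""
    counts = []
    for step in range(0, len(path), obs_every):
        theta = path[step]
        row = []
        for w_n, b_n in zip(weights, bias):
            log_rate = b_n + sum(w * math.cos(th) for w, th in zip(w_n, theta))
            lam = math.exp(log_rate) * dt * obs_every  # expected count in this bin
            row.append(sample_poisson(lam, rng))
        counts.append(row)
    return counts


# Example: two noisy, weakly coupled oscillators observed by three neurons in 10 ms bins.
rng = random.Random(0)
path = simulate_coupled_oscillators(
    omega=[2 * math.pi * 4.0, 2 * math.pi * 6.0], coupling=[[0.0, 1.0], [1.0, 0.0]],
    sigma=0.3, drive=lambda t: 0.0, theta0=[0.0, 1.0], t_end=1.0, dt=0.001, rng=rng,
)
counts = observe_spikes(path, weights=[[1.0, -0.5], [0.2, 0.8], [-1.0, 1.0]],
                        bias=[2.0, 2.5, 1.5], obs_every=10, dt=0.001, rng=rng)
```

In the paper's setting, the drift, diffusion, and readout parameters would be learned with variational inference rather than fixed by hand as they are here.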

Paper Structure

This paper contains 20 sections, 20 equations, 3 figures, and 2 tables.

Figures (3)

  • Figure 1: Schematic illustration of the framework. (A) The inference model. Experimental stimuli $v$ are processed by an input encoder ($\eta_{\theta}$) to produce a continuous-time input function $u$. Neural observations $y$ and behavioral observations $b$ are concatenated and processed by an observation encoder ($\gamma_{\phi}$) to produce a continuous-time context $c$ and the distribution parameters of an initial density $q(\tilde{x}_0 \mid y)$. An augmented SDE, serving as the approximate posterior, takes an initial condition $\tilde{x}_0$, $u_t$, and $c_t$ to infer latent trajectories $\tilde{x}_t$ at times $t > 0$. This augmented SDE is trained to match the generative SDE serving as the prior, which models latent trajectories $x_t$ based on an initial condition $x_0$ and $u_t$. The model parameters $\theta$ and variational parameters $\phi$ are learned by maximizing the evidence lower bound (ELBO) using variational inference. (B) The generation model. After training, given a stimulus, the input encoder ($\eta_{\theta}$) generates an encoded input signal $u$. The learned generative SDE uses $u_t$ as a control signal, together with an initial state $x_0$ sampled from $p(x_0)$, to generate latent trajectories $x_t$. (When historical observations are available, the initial state is instead sampled from the approximate posterior conditioned on those observations.) These trajectories are then passed to a neural decoder ($\lambda_{\theta}$) to predict neural activity $\hat{y}$ and a behavioral decoder ($\rho_{\theta}$) to predict behavioral responses $\hat{b}$.
  • Figure 2: Results on a simulated spiking neural system. (a) Five-fold cross-validation results comparing the performance of a latent ODE and a latent SDE fitted to simulated data for different numbers of latent states. (b) Five-fold cross-validation results of different latent variable models under varying levels of process noise in the ground-truth system. (c) A sample of spikes generated by the system versus spikes generated by the latent SDE model in response to the same input. (d) Samples of underlying firing rates generated by the system versus the mean and standard deviation of 30 samples from a trained latent SDE model. (e) Phase portrait of the ground-truth system versus the drift vector field of a trained latent SDE model at one time point under a fixed input.
  • Figure 3: Generative modeling of neural and behavioral data via latent coupled oscillators across three datasets. (a-c) Perturbed reach task: the generative model takes as input the target location and the manipulandum forces/torques and is tasked with predicting the responses of 64 neurons in Area 2 together with the monkey's behavioral response, measured as cursor position. The inferred oscillator frequencies and the coupling strengths (for a sample trial) from one training run are shown in (c). (d-f) Visual decision-making task: the input consists of contrast levels as well as the timing of presentation and of the go cue. The model is trained to predict both neural activity across multiple brain regions and wheel velocity. (g,h,k) Delayed reach task: unlike in the other tasks, this model receives only preparatory neural activity before movement onset and is trained to predict subsequent neural activity in the dorsal premotor cortex (PMd) and primary motor cortex (M1) during movement execution. Abbreviations: MOp: primary motor area; LSc: lateral sensory cortex; PT: posterior thalamus; CP: caudoputamen; LSr: lateral sensory rostral area; PMd: dorsal premotor cortex; M1: primary motor cortex. Monkey and mouse illustrations adapted from scidraw.io (CC-BY 4.0).
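The training objective in Figure 1 is the evidence lower bound. As a hedged, single-sample sketch, it can be assembled from three parts: a reconstruction term on the observed counts, a KL term between the approximate initial-state density $q(\tilde{x}_0 \mid y)$ and the prior $p(x_0)$, and the path KL between the posterior and prior SDEs. The Poisson likelihood, the diagonal-Gaussian initial densities, and the helper names (`poisson_log_lik`, `gaussian_kl`, `elbo`) are assumptions for illustration. The behavioral likelihood term is omitted, and a real implementation would compute this inside an autodiff framework to obtain gradients with respect to $\theta$ and $\phi$.

```python
import math


def poisson_log_lik(counts: list[int], rates: list[float]) -> float:
    """Sum of log Poisson(y | rate) over flattened (time bin, neuron) entries; rates must be > 0."""
    return sum(y * math.log(r) - r - math.lgamma(y + 1) for y, r in zip(counts, rates))


def gaussian_kl(
    mu_q: list[float], var_q: list[float], mu_p: list[float], var_p: list[float]
) -> float:
    """KL(N(mu_q, diag var_q) || N(mu_p, diag var_p)), summed over latent dimensions."""
    return sum(
        0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)
        for mq, vq, mp, vp in zip(mu_q, var_q, mu_p, var_p)
    )


def elbo(
    counts: list[int],
    rates: list[float],   # decoder rates evaluated on one posterior latent trajectory
    mu_q: list[float],
    var_q: list[float],
    mu_p: list[float],
    var_p: list[float],
    path_kl: float,       # e.g. a Girsanov estimate accumulated along the same trajectory
) -> float:
    """Single-sample ELBO: reconstruction - initial-state KL - path KL (hypothetical helper)."""
    return poisson_log_lik(counts, rates) - gaussian_kl(mu_q, var_q, mu_p, var_p) - path_kl


# Example with two observation entries and a one-dimensional initial state.
value = elbo(counts=[2, 0], rates=[1.5, 0.4], mu_q=[0.3], var_q=[0.5],
             mu_p=[0.0], var_p=[1.0], path_kl=0.25)
```

Averaging this quantity over several posterior samples gives a lower-variance estimate of the bound.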