Representation learning for neural population activity with Neural Data Transformers

Joel Ye, Chethan Pandarinath

TL;DR

The paper introduces the Neural Data Transformer (NDT), a Transformer-based, non-recurrent model for inferring neural population firing rates from spiking data. It demonstrates that parallelized attention can capture autonomous neural dynamics on both synthetic datasets and real motor cortex activity while delivering substantial speedups over recurrent baselines, enabling real-time inference. Through ablations and data-efficiency analyses, the study shows that log-rate outputs, zero masking, and strong regularization are critical for performance, and it identifies the limitation that the NDT struggles with non-autonomous inputs, suggesting avenues like hybrid architectures and multi-modal data. Overall, the NDT offers a scalable, fast alternative to RNNs for neuroscience applications, with implications for real-time brain-machine interfaces and large-scale neural-data modeling.

Abstract

Neural population activity is theorized to reflect an underlying dynamical structure. This structure can be accurately captured using state space models with explicit dynamics, such as those based on recurrent neural networks (RNNs). However, using recurrence to explicitly model dynamics necessitates sequential processing of data, slowing real-time applications such as brain-computer interfaces. Here we introduce the Neural Data Transformer (NDT), a non-recurrent alternative. We test the NDT's ability to capture autonomous dynamical systems by applying it to synthetic datasets with known dynamics and data from monkey motor cortex during a reaching task well-modeled by RNNs. The NDT models these datasets as well as state-of-the-art recurrent models. Further, its non-recurrence enables 3.9ms inference, well within the loop time of real-time applications and more than 6 times faster than recurrent baselines on the monkey reaching dataset. These results suggest that an explicit dynamics model is not necessary to model autonomous neural population dynamics. Code: https://github.com/snel-repo/neural-data-transformers

Paper Structure

This paper contains 19 sections, 3 equations, 8 figures, 3 tables.

Figures (8)

  • Figure 1: Sequential vs. parallel models. (a) Unsupervised models of sequential spiking activity take in binned spikes (with 2 channels in this schematic) and output inferred rates. A likelihood loss trains the network to output the most likely rates. (b) A Transformer architecture (top) performs parallel modeling, contrasting with RNNs (bottom) and methods like GPFA which use sequential processing. (c) Spike input and inferred rate examples for the NDT applied to the reaching dataset. Rows are sorted by the time of each channel's maximum rate in the first trial, so as to demonstrate correspondence between observed activity and rates.
  • Figure 2: Transformer architecture. (a) A single Transformer layer, as in Fig. 1. The full encoder stacks several of these layers. (b) Inputs to Transformer layers are normalized ("Norm" blocks), enriched through contextual information ("Self-Attention" blocks), and passed through a feedforward module ("MLP" multi-layer perceptron blocks). Blocks with the same label share parameters. The circled plus symbols indicate addition. (c) Inputs at each time step are multiplied by three learned weight matrices (not shown) to create three sets of vectors: the queries, keys, and values. (d) Assembling a single timestep's output: Dot products are computed between the timestep's query and every key, yielding similarity scores. Scores are normalized to sum to 1 to form weights. With these weights, a weighted sum of value vectors is computed and returned as the timestep's output.
  • Figure 3: Transformer training. The model is trained with masked modeling [devlin2019bert, neurips2019keshtkaran]; that is, model outputs are optimized to maximize the likelihood of the masked activity given the context provided by unmasked activity.
  • Figure 3: Log-rates, zero masks, and heavy regularization are all critical to the NDT's performance. As in the main experiments, the $\pm$ interval indicates SEM over 3 randomly initialized hyperparameter grid searches.
  • Figure 4: Modeling synthetic data. (a) In each quadrant, we plot the ground truth firing rate for a sampled neuron and information from 8 of its trials. We refer to these trials with the same initial conditions (and thus firing rates) as having the same condition and show 2 conditions (columns) for each of the synthetic datasets (rows). For these trials, we show the generated spikes (bottom) and inferred rates from AutoLFADS and NDT (top). Inferred firing rates closely match generating ground truth rates. Vertical bar denotes spikes per bin. Horizontal bar indicates 10% of the trial length, 5 bins for Lorenz and 10 bins for Chaotic RNN. (b) Across a hyperparameter sweep on the Lorenz dataset, models that achieve better likelihoods yield more accurate inference of the underlying rates. NLL is averaged across bins and channels.
  • ...and 3 more figures
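
The attention computation described in the Figure 2 caption, panels (c) and (d), can be sketched in a few lines of plain Python. This is a minimal single-head illustration, not the authors' implementation: the function names (`matmul`, `softmax`, `self_attention`), the random weights, and the toy sizes are all hypothetical. The caption only says scores are "normalized to sum to 1"; the softmax and the division by the square root of the key dimension are assumptions taken from the standard Transformer formulation.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(scores):
    """Turn a list of scores into non-negative weights that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention over a sequence of timesteps.

    x is a (timesteps x features) matrix; w_q, w_k, w_v are the three
    learned weight matrices (random here, for illustration only).
    Returns the per-timestep outputs and the attention weights.
    """
    # (c) Project every timestep into queries, keys, and values.
    queries, keys, values = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    scale = math.sqrt(len(keys[0]))  # assumption: standard scaled dot product
    outputs, all_weights = [], []
    for q in queries:
        # (d) Dot product between this timestep's query and every key.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
        weights = softmax(scores)
        # Weighted sum of value vectors is this timestep's output.
        out = [sum(w * v[j] for w, v in zip(weights, values))
               for j in range(len(values[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights


def random_matrix(rows, cols):
    """Hypothetical helper: a matrix of standard-normal entries."""
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
x = random_matrix(5, 4)  # toy input: 5 timesteps, 4 features each
w_q, w_k, w_v = random_matrix(4, 3), random_matrix(4, 3), random_matrix(4, 3)
outputs, attn = self_attention(x, w_q, w_k, w_v)
```

Each timestep's output depends only on the projected inputs, not on the previous timestep's output, so the loop over queries could run in parallel. This is the property the paper contrasts with the sequential processing of RNNs.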