Representation learning for neural population activity with Neural Data Transformers
Joel Ye, Chethan Pandarinath
TL;DR
The paper introduces the Neural Data Transformer (NDT), a non-recurrent, Transformer-based model for inferring neural population firing rates from spiking data. It shows that parallelized attention can capture autonomous neural dynamics in both synthetic datasets and real motor cortex activity, with 3.9 ms inference on the monkey reaching dataset, more than 6 times faster than recurrent baselines and fast enough for real-time use. Ablations and data-efficiency analyses indicate that log-rate outputs, zero masking, and strong regularization are critical for performance. The study also identifies a limitation: the NDT struggles with non-autonomous inputs, which suggests avenues such as hybrid architectures and multi-modal data. Overall, the NDT offers a scalable, fast alternative to RNNs for neuroscience applications, with implications for real-time brain-machine interfaces and large-scale neural-data modeling.
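To make the zero masking and log-rate outputs mentioned above concrete, here is a minimal NumPy sketch of a masked Poisson objective. It is an illustration under stated assumptions, not the authors' implementation: the function names `zero_mask` and `poisson_nll_from_log_rates`, the choice to mask whole time steps, and the `mask_ratio` parameter are all hypothetical. The model itself is omitted; `log_rates` stands in for whatever the network would predict.

```python
import numpy as np

def zero_mask(spikes, mask_ratio, rng):
    """Zero out a random subset of time steps (assumed masking scheme).

    spikes: (time, neurons) array of binned spike counts.
    Returns the masked copy and a boolean vector marking masked steps.
    """
    masked_steps = rng.random(spikes.shape[0]) < mask_ratio
    masked = spikes.copy()
    masked[masked_steps, :] = 0  # replace masked inputs with zeros
    return masked, masked_steps

def poisson_nll_from_log_rates(log_rates, spikes, masked_steps):
    """Poisson negative log-likelihood on masked steps only.

    The model is assumed to output log firing rates, so the rate is
    exp(log_rates). The constant log(k!) term is dropped.
    """
    if not masked_steps.any():
        return 0.0  # nothing was masked, so there is nothing to score
    nll = np.exp(log_rates) - spikes * log_rates
    return float(nll[masked_steps].mean())

# Hypothetical usage with synthetic counts and a constant "prediction".
rng = np.random.default_rng(0)
spikes = rng.poisson(2.0, size=(50, 8))
masked_input, masked_steps = zero_mask(spikes, mask_ratio=0.25, rng=rng)
log_rates = np.full(spikes.shape, np.log(2.0))  # stand-in for model output
loss = poisson_nll_from_log_rates(log_rates, spikes, masked_steps)
```

Predicting log-rates keeps the implied rate positive without clipping, and scoring only the masked steps forces the model to infer activity from surrounding context rather than copy its input.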
Abstract
Neural population activity is theorized to reflect an underlying dynamical structure. This structure can be accurately captured using state space models with explicit dynamics, such as those based on recurrent neural networks (RNNs). However, using recurrence to explicitly model dynamics necessitates sequential processing of data, slowing real-time applications such as brain-computer interfaces. Here we introduce the Neural Data Transformer (NDT), a non-recurrent alternative. We test the NDT's ability to capture autonomous dynamical systems by applying it to synthetic datasets with known dynamics and to data from monkey motor cortex recorded during a reaching task that is well modeled by RNNs. The NDT models these datasets as well as state-of-the-art recurrent models do. Further, its non-recurrence enables 3.9 ms inference, well within the loop time of real-time applications and more than 6 times faster than recurrent baselines on the monkey reaching dataset. These results suggest that an explicit dynamics model is not necessary to model autonomous neural population dynamics. Code: https://github.com/snel-repo/neural-data-transformers
