Towards a "universal translator" for neural dynamics at single-cell, single-spike resolution

Yizi Zhang, Yanchen Wang, Donato Jimenez-Beneto, Zixuan Wang, Mehdi Azabou, Blake Richards, Olivier Winter, International Brain Laboratory, Eva Dyer, Liam Paninski, Cole Hurwitz

TL;DR

This work addresses the challenge of building a foundation model that can read neural activity across brain regions at single-cell, single-spike resolution. It introduces Multi-task-masking (MtM), a self-supervised training framework that alternates among masking schemes (causal, neuron, intra-region, inter-region) and uses a learnable prompt for mode switching at inference, enabling cross-scale representation learning. Evaluated on the International Brain Laboratory Neuropixels dataset, MtM consistently outperforms strong baselines (temporal masking) and shows strong multi-task generalization, especially when pretraining spans many animals and sessions; behavior decoding improves notably when conditioning on region-specific structure, with scaling benefits evident up to 34 pretraining sessions. The results suggest MtM as a viable path toward a universal translator of brain dynamics, capable of leveraging diverse, multi-region neural data to generalize to unseen animals and tasks, ultimately enabling broader, brain-wide analyses at fine temporal and cellular resolution.

Abstract

Neuroscience research has made immense progress over the last decade, but our understanding of the brain remains fragmented and piecemeal: the dream of probing an arbitrary brain region and automatically reading out the information encoded in its neural activity remains out of reach. In this work, we build towards a first foundation model for neural spiking data that can solve a diverse set of tasks across multiple brain areas. We introduce a novel self-supervised modeling approach for population activity in which the model alternates between masking out and reconstructing neural activity across different time steps, neurons, and brain regions. To evaluate our approach, we design unsupervised and supervised prediction tasks using the International Brain Laboratory repeated site dataset, which consists of Neuropixels recordings targeting the same brain locations across 48 animals and experimental sessions. The prediction tasks include single-neuron and region-level activity prediction, forward prediction, and behavior decoding. We demonstrate that our multi-task-masking (MtM) approach significantly improves the performance of current state-of-the-art population models and enables multi-task learning. We also show that by training on multiple animals, we can improve the generalization ability of the model to unseen animals, paving the way for a foundation model of the brain at single-cell, single-spike resolution.
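
To make the four masking schemes concrete, here is a minimal Python sketch. It is not the authors' implementation. It assumes binned spike counts stored as a time-by-neuron array. The function name make_masks, the 30% mask ratio, the number of masked future steps, the round-robin schedule, the zero-filling of hidden entries, and the region labels other than CA1 are all illustrative assumptions. The sketch separates what is hidden from the model input from what is scored, because intra-region prediction hides the other regions without reconstructing them.

```python
import numpy as np

# Scheme names follow the text; everything else here is an illustrative assumption.
SCHEMES = ("causal", "neuron", "intra-region", "inter-region")


def make_masks(scheme, n_time, regions, rng, ratio=0.3, n_future=1):
    """Return (hidden, target) boolean arrays of shape (n_time, n_neurons).

    hidden marks entries removed from the model input.
    target marks entries the model is scored on reconstructing.
    """
    regions = np.asarray(regions)
    n_neurons = regions.size
    hidden = np.zeros((n_time, n_neurons), dtype=bool)

    if scheme == "causal":
        # Hide the last n_future time steps of every neuron (forward prediction).
        hidden[n_time - n_future:, :] = True
        target = hidden.copy()
    elif scheme == "neuron":
        # Hide a random subset of neurons for the whole trial (co-smoothing).
        n_mask = max(1, int(round(ratio * n_neurons)))
        cols = rng.choice(n_neurons, size=n_mask, replace=False)
        hidden[:, cols] = True
        target = hidden.copy()
    elif scheme == "inter-region":
        # Hide every neuron of one region; predict it from the other regions.
        region = rng.choice(np.unique(regions))
        hidden[:, regions == region] = True
        target = hidden.copy()
    elif scheme == "intra-region":
        # Score a subset of one region, using only that region as context.
        region = rng.choice(np.unique(regions))
        in_region = np.flatnonzero(regions == region)
        n_mask = max(1, int(round(ratio * in_region.size)))
        cols = rng.choice(in_region, size=n_mask, replace=False)
        target = np.zeros_like(hidden)
        target[:, cols] = True
        # Other regions are hidden as well, but they are not reconstructed.
        hidden[:, regions != region] = True
        hidden |= target
    else:
        raise ValueError(f"unknown masking scheme: {scheme}")
    return hidden, target


rng = np.random.default_rng(0)
regions = ["CA1", "CA1", "CA1", "PO", "PO", "LP"]  # hypothetical region labels
spikes = rng.poisson(0.5, size=(10, len(regions)))  # synthetic spike counts

for step in range(len(SCHEMES)):
    scheme = SCHEMES[step % len(SCHEMES)]  # assumed round-robin alternation
    prompt_id = SCHEMES.index(scheme)      # index of the scheme's prompt token
    hidden, target = make_masks(scheme, spikes.shape[0], regions, rng)
    model_input = np.where(hidden, 0, spikes)  # zero-fill is an assumption
    print(scheme, prompt_id, int(hidden.sum()), int(target.sum()))
```

In a training loop, prompt_id would select the learnable prompt embedding passed to the model together with model_input, and the reconstruction loss would be computed only where target is true. At evaluation time the same prompt index would be supplied for the matching downstream task.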

Paper Structure (40 sections, 1 equation, 10 figures, 7 tables)

Figures (10)

  • Figure 1: Schematic illustration of our multi-task-masking (MtM) approach. (A) We introduce four metrics for evaluating foundation models of neural population activity: neuron co-smoothing, causal prediction, inter-region prediction, and intra-region prediction. For each masking scheme, the colored area indicates what is masked and then reconstructed for evaluation. For intra-region prediction, the colored areas with hatched lines indicate areas that are masked but not reconstructed for evaluation. Each metric can be associated with a specific masking scheme during training (T1, T2, etc.). (B) During training, we alternate between the masking schemes and pair each with a learnable "prompt" token that tells the model which task is active (Tay et al., 2022). During evaluation, we provide the prompt token associated with the downstream task to perform test-time adaptation of the model. Our MtM approach is architecture-agnostic because masking is performed on the input data (not the tokens). For a full discussion of MtM, see the Methods section.
  • Figure 1: The performance of single-session NDT1 trained with various masking schemes on neural activity reconstruction tasks. The metrics are in units of bits per spike (bps), averaged across all neurons in one session. A higher bps value indicates better performance.
  • Figure 2: Comparison of the temporal masking baseline and our proposed MtM method on single-session data. (A) and (B) show trial-averaged raster maps of CA1 for ground-truth data, MtM, and the temporal baseline. (A) Predictions from MtM and the temporal baseline under inter-region masking, where neurons in CA1 are predicted from all other brain regions. We highlight two areas (1 and 2) where MtM shows qualitatively better predictions of activity. (B) Predictions from MtM and the temporal baseline under intra-region masking, where all neurons in CA1 are predicted from other neurons in the same brain region. We again highlight two areas (1 and 2) where MtM shows qualitatively better predictions of activity. (C) Activity reconstruction and behavior decoding across 39 sessions for MtM and temporal masking. Each point represents one session. For activity reconstruction, we report the average bps. For choice and whisker motion energy decoding, we report the average accuracy and $R^2$, respectively, across all test trials. We use the NDT1 architecture for all comparisons.
  • Figure 3: Fine-tuning performance comparison of NDT1-stitch pretrained with MtM vs. temporal masking for activity reconstruction and behavior decoding across 5 held-out sessions. For activity reconstruction, each point shows the average bps across all neurons in a held-out session. For behavior decoding, each point shows the trial-averaged accuracy (choice) and $R^2$ (whisker motion energy, WME).
  • Figure 4: Comparison of scaling curves between NDT1-stitch pretrained with the MtM method vs. the temporal masking baseline. The reported metrics, namely neuron-averaged bits per spike (bps), choice decoding accuracy, and whisker motion energy decoding $R^2$, are averaged over all 5 held-out sessions. We fine-tune each pretrained model with its self-supervised loss (MtM or temporal) on the 5 held-out sessions and then evaluate it with all of our metrics. "Num of Sessions" denotes the number of sessions used for pretraining.
  • ...and 5 more figures
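
Several captions above report activity reconstruction in bits per spike (bps), averaged across neurons. The captions do not define the metric, so the sketch below assumes the commonly used Poisson form: the log-likelihood gain of the predicted rates over a constant mean-rate model, divided by the spike count and by log 2. The function names bits_per_spike and mean_bps are hypothetical, and the paper's exact computation may differ.

```python
import numpy as np


def poisson_loglik(rates, spikes):
    """Poisson log-likelihood up to the log(spikes!) term.

    That term depends only on the data, so it cancels when two models are
    compared on the same spikes.
    """
    rates = np.clip(rates, 1e-9, None)  # avoid log(0)
    return float(np.sum(spikes * np.log(rates) - rates))


def bits_per_spike(rates, spikes):
    """Assumed bps definition for one neuron (1-D arrays over time bins)."""
    rates = np.asarray(rates, dtype=float)
    spikes = np.asarray(spikes, dtype=float)
    n_spikes = spikes.sum()
    if n_spikes == 0:
        return float("nan")  # undefined for a silent neuron
    null_rates = np.full_like(rates, spikes.mean())  # mean-rate baseline
    gain = poisson_loglik(rates, spikes) - poisson_loglik(null_rates, spikes)
    return gain / (n_spikes * np.log(2))


def mean_bps(rates, spikes):
    """Average bps over neurons (columns), skipping silent neurons."""
    per_neuron = [
        bits_per_spike(rates[:, j], spikes[:, j]) for j in range(spikes.shape[1])
    ]
    return float(np.nanmean(per_neuron))
```

Under this definition a model that predicts each neuron's mean firing rate scores 0 bps, and a positive value means the predicted rates explain the spikes better than that baseline, which is consistent with the caption's note that higher bps is better.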