Population Transformer: Learning Population-level Representations of Neural Activity

Geeling Chau, Christopher Wang, Sabera Talukder, Vighnesh Subramaniam, Saraswati Soedarmadji, Yisong Yue, Boris Katz, Andrei Barbu

TL;DR

This paper tackles learning population-level representations from neural time-series data with highly variable electrode configurations. It introduces Population Transformer (PopT), a modular transformer-based spatial aggregator that sits on top of frozen temporal embeddings and is pretrained with two discriminative self-supervised objectives. Pretraining yields subject-generic, spatially contextual channel representations that improve downstream decoding across iEEG and EEG tasks, reduce data and compute requirements, and generalize to unseen subjects. The authors also provide interpretability tools to map connectivity and candidate functional brain regions from PopT weights, and release pretrained models and code for off-the-shelf use in multi-channel neural decoding.

Abstract

We present a self-supervised framework that learns population-level codes for arbitrary ensembles of neural recordings at scale. We address key challenges in scaling models with neural time-series data, namely, sparse and variable electrode distribution across subjects and datasets. The Population Transformer (PopT) stacks on top of pretrained temporal embeddings and enhances downstream decoding by enabling learned aggregation of multiple spatially-sparse data channels. The pretrained PopT lowers the amount of data required for downstream decoding experiments, while increasing accuracy, even on held-out subjects and tasks. Compared to end-to-end methods, this approach is computationally lightweight, while achieving similar or better decoding performance. We further show how our framework is generalizable to multiple time-series embeddings and neural data modalities. Beyond decoding, we interpret the pretrained and fine-tuned PopT models to show how they can be used to extract neuroscience insights from large amounts of data. We release our code as well as a pretrained PopT to enable off-the-shelf improvements in multi-channel intracranial data decoding and interpretability. Code is available at https://github.com/czlwang/PopulationTransformer.
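
The idea of learned aggregation over a variable number of channels can be illustrated with a minimal sketch. It is not the released PopT implementation: it assumes a single self-attention head, randomly initialised weights, and hypothetical names (`aggregate_channels`, `cls_token`, `w_q`, `w_k`, `w_v`). The point is only that one set of weights can pool ensembles of different sizes into a fixed-length summary vector read from a [CLS] token.

```python
import numpy as np

def aggregate_channels(embeddings, cls_token, w_q, w_k, w_v):
    """Pool per-channel embeddings with one self-attention head.

    embeddings: (n_channels, d) temporal embeddings, one row per electrode.
    Returns the attended [CLS] vector of shape (d,).
    """
    # Prepend the [CLS] token so it can attend to every channel.
    tokens = np.vstack([cls_token[None, :], embeddings])  # (1 + n_channels, d)
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return (weights @ v)[0]  # row 0 is the [CLS] output

rng = np.random.default_rng(0)
d = 16  # embedding size, chosen arbitrarily for the sketch
cls_token = rng.normal(size=d)
w_q, w_k, w_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

# Two hypothetical subjects with different electrode counts share the same weights.
subject_a = rng.normal(size=(90, d))
subject_b = rng.normal(size=(21, d))
out_a = aggregate_channels(subject_a, cls_token, w_q, w_k, w_v)
out_b = aggregate_channels(subject_b, cls_token, w_q, w_k, w_v)
```

Both outputs have shape `(d,)` regardless of channel count. Without any position information the pooled vector is also unchanged when the channels are reordered, so spatial context has to be supplied separately.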

Paper Structure (20 sections, 17 figures, 8 tables, 1 algorithm)

Figures (17)

  • Figure 1: Schematic of our approach. The inputs to our model (a) are the neural activities from a collection of electrodes in a given time interval (bottom). These are passed to a frozen temporal embedding model (dotted red outline: BrainBERT (Wang et al., 2023) shown), which produces a set of time embedding vectors (yellow). The 3D positions of each electrode (red) are summed with these vectors to produce the model inputs (orange, lower). PopT produces space-contextual embeddings (orange, top) for each electrode and a [CLS] token (blue, top), which can be fine-tuned for downstream tasks. In pretraining, PopT learns two objectives simultaneously. In the first, (b) PopT determines whether two different sets of electrodes (orange vs brown) represent consecutive or non-consecutive times. In the second objective, (c) PopT must determine whether an input channel has been replaced with activity at a random other time that is inconsistent with the majority of inputs.
  • Figure 2: Compared to common aggregation approaches, pretrained PopT consistently yields better downstream decoding across tasks, data modalities, and temporal embedding types. NPopT = Non-pretrained PopT. (a) performance on four audio-linguistic iEEG tasks with 90 electrodes. Grey bars denote standard error across subjects. (b) performance on an abnormal detection EEG task with 21 electrodes. Grey bars denote standard deviation across 5 random seeds.
  • Figure 3: Pretrained PopT downstream performance scales better with ensemble size. As the channel ensemble size increases from 1 to 50 (x-axis), the decoding performance (y-axis) of pretrained PopT (green) not only exceeds that of the non-pretrained approaches (orange, purple, grey) but also improves more with each increase in channel count. Shaded bands show the standard error across subjects.
  • Figure 4: Pretrained PopT is more sample efficient when fine-tuning. Varying the number of samples available to each model at train time (x-axis), we see that the pretrained PopT is highly sample efficient, requiring only a fraction of samples (fewer than 500 samples out of 5-10k of the full dataset) to reach the full performance level of baseline aggregation approaches (dashed lines). Bands show standard error across test subjects. Stars indicate performance with full fine-tuning dataset.
  • Figure 5: Pretrained PopT is consistently compute efficient when fine-tuning. Number of steps required for each model to reach final performance during fine-tuning (dashed lines). The pretrained PopT consistently requires fewer than 750 steps (each step is an update on a batch size of 256) to converge. Bands show standard error across subjects. Stars indicate fully trained performance.
  • ...and 12 more figures
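
The input construction and the channel-replacement objective described in the Figure 1 caption can be sketched as follows. This is a sketch under stated assumptions, not the authors' code: the caption says only that 3D electrode positions are summed with the temporal embeddings, so the sinusoidal encoding used here is an assumed scheme, and `position_encoding`, `make_replacement_example`, and `replace_frac` are hypothetical names and parameters.

```python
import numpy as np

def position_encoding(coords, dim):
    """Encode 3D coordinates with sinusoids (assumed scheme, not from the source).

    coords: (n_channels, 3). dim must be divisible by 6 so that each axis
    gets an equal, even number of sine/cosine features.
    """
    per_axis = dim // 3
    freqs = 1.0 / (10000 ** (np.arange(0, per_axis, 2) / per_axis))
    angles = coords[:, :, None] * freqs[None, None, :]  # (n, 3, per_axis / 2)
    enc = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
    return enc.reshape(coords.shape[0], dim)

def make_replacement_example(emb_t, emb_other, coords, replace_frac, rng):
    """Build one example for the replaced-channel objective.

    emb_t:     (n, d) temporal embeddings of all channels at one time.
    emb_other: (n, d) embeddings of the same channels at a random other time.
    Returns model inputs (n, d) and per-channel labels (1 = replaced).
    """
    n, d = emb_t.shape
    n_replace = max(1, int(round(replace_frac * n)))
    idx = rng.choice(n, size=n_replace, replace=False)
    labels = np.zeros(n, dtype=np.int64)
    labels[idx] = 1
    mixed = emb_t.copy()
    mixed[idx] = emb_other[idx]  # swap in activity from the other time
    # Positions are summed with the embeddings, as in the caption.
    return mixed + position_encoding(coords, d), labels

rng = np.random.default_rng(0)
n, d = 10, 12
emb_t = rng.normal(size=(n, d))
emb_other = rng.normal(size=(n, d))
coords = rng.uniform(-50, 50, size=(n, 3))  # made-up electrode coordinates
inputs, labels = make_replacement_example(emb_t, emb_other, coords, 0.2, rng)
```

A per-channel classifier on the model's output embeddings would then be trained to predict `labels`. The consecutive versus non-consecutive objective could be set up analogously, by pairing two electrode subsets drawn from adjacent or distant time windows and predicting the pair label from the [CLS] token.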