
JEDI: Jointly Embedded Inference of Neural Dynamics

Anirudh Jamkhandi, Ali Korojy, Olivier Codol, Guillaume Lajoie, Matthew G. Perich

TL;DR

This paper introduces JEDI, a hierarchical model that captures neural dynamics across tasks and contexts by learning a shared embedding space over RNN weights. Jointly learning contextual embeddings and recurrent weights provides scalable and generalizable inference of brain dynamics from recordings alone.

Abstract

Animal brains flexibly and efficiently achieve many behavioral tasks with a single neural network. A core goal in modern neuroscience is to map the mechanisms of the brain's flexibility onto the dynamics of neural populations. However, identifying task-specific dynamical rules from limited, noisy, and high-dimensional experimental neural recordings remains a major challenge, as experimental data often provide only partial access to brain states and dynamical mechanisms. While recurrent neural networks (RNNs) directly constrained by neural data have been effective in inferring underlying dynamical mechanisms, they are typically limited to single-task domains and struggle to generalize across behavioral conditions. Here, we introduce JEDI, a hierarchical model that captures neural dynamics across tasks and contexts by learning a shared embedding space over RNN weights. This model recapitulates individual samples of neural dynamics while scaling to arbitrarily large and complex datasets, uncovering shared structure across conditions in a single, unified model. Using simulated RNN datasets, we demonstrate that JEDI accurately learns robust, generalizable, condition-specific embeddings. By reverse-engineering the weights learned by JEDI, we show that it recovers ground-truth fixed-point structures and reveals key features of the underlying neural dynamics in their eigenspectra. Finally, we apply JEDI to motor cortex recordings from monkeys performing reaching movements to extract mechanistic insight into the neural dynamics of motor control. Our work shows that joint learning of contextual embeddings and recurrent weights provides scalable and generalizable inference of brain dynamics from recordings alone.
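A minimal sketch of the core idea, a shared embedding space over RNN weights, under loudly stated assumptions: the hypernetwork is taken to be a single linear map from a context embedding to an offset added to a shared base weight matrix, and each generated network is a discrete-time tanh RNN. The names `generate_weights`, `run_rnn`, `W_base`, `H`, and `B` are hypothetical, and the paper's actual architecture may differ. All parameters are random here, so the sketch shows only the forward pass and a look at the eigenvalues of each generated weight matrix; in training, the embeddings and hypernetwork parameters would be fitted jointly from a loss between the RNN output and the recordings.

```python
import numpy as np

rng = np.random.default_rng(0)

N_UNITS = 32   # hidden units in each generated RNN (assumed)
EMB_DIM = 4    # size of each context embedding (assumed)
N_INPUTS = 2   # external input channels (assumed)

# Hypothetical hypernetwork parameters shared by all contexts:
# a base recurrent matrix plus a linear map from embedding to a weight offset.
W_base = rng.normal(0.0, 1.0 / np.sqrt(N_UNITS), (N_UNITS, N_UNITS))
H = rng.normal(0.0, 0.1 / np.sqrt(N_UNITS), (N_UNITS * N_UNITS, EMB_DIM))
B = rng.normal(0.0, 1.0, (N_UNITS, N_INPUTS))  # input weights, also shared

def generate_weights(embedding):
    """Map one context embedding to an N_UNITS x N_UNITS recurrent matrix."""
    offset = H @ embedding
    return W_base + offset.reshape(N_UNITS, N_UNITS)

def run_rnn(W, inputs, x0=None):
    """Discrete-time tanh RNN: x_{t+1} = tanh(W x_t + B u_t)."""
    x = np.zeros(N_UNITS) if x0 is None else x0
    states = []
    for u in inputs:
        x = np.tanh(W @ x + B @ u)
        states.append(x)
    return np.stack(states)

# One embedding per context; random here, learned jointly with H in practice.
embeddings = {
    "context_a": rng.normal(size=EMB_DIM),
    "context_b": rng.normal(size=EMB_DIM),
}
inputs = rng.normal(size=(50, N_INPUTS))  # placeholder input sequence

for name, e in embeddings.items():
    W = generate_weights(e)
    states = run_rnn(W, inputs)
    eigvals = np.linalg.eigvals(W)  # spectrum of the generated weights
    print(f"{name}: states {states.shape}, "
          f"spectral radius {np.abs(eigvals).max():.3f}, "
          f"max |imag| {np.abs(eigvals.imag).max():.3f}")
```

Because all contexts share `W_base`, `H`, and `B`, the only per-context parameters in this sketch are the low-dimensional embeddings, and each generated `W` can be analysed directly, for example through its eigenvalues.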

Paper Structure

This paper contains 34 sections, 5 equations, and 17 figures.

Figures (17)

  • Figure 1: JEDI uses a hypernetwork-based framework to flexibly generate RNN weights from contextual inputs; all parameters are learned directly from a loss computed between the RNN output and time-series neural recordings.
  • Figure 2: Quantifying the quality of the embeddings. A) Synthetic multi-task data generation setup. B) 2D PCA visualization of context embeddings; each point corresponds to a trial, color-coded by input signal type. C) Training reconstruction accuracy ($R^2$) across different methods. D) Accuracy of task classification from the learned embeddings. E) Generalization accuracy ($R^2$) when using the center of the learned embeddings for each task. F) Confusion matrix of generalization $R^2$ scores, expanding on the results in Panel E; rows represent training tasks and columns represent test tasks.
  • Figure 3: Impact of embedding noise on model performance, comparing reconstruction $R^2$ as increasing noise is applied to the embeddings. The added noise was scaled by the standard deviation of each model's embeddings.
  • Figure 4: A) We drove the chaotic Teacher RNN with sinusoidal inputs at different frequencies (1–10 Hz). B) The eigenspectra of the weights inferred with JEDI exhibit a characteristic expansion along the imaginary axis as input frequency increases.
  • Figure 5: JEDI identifies fixed-point structure in task-trained networks. A) We fit JEDI to a network trained to perform the MemoryPro task, which has four contextual trial phases. B) Context embeddings learned by JEDI, colored by trial phase (left) and response direction (right). C) Fixed-point structure for the four trial phases in the ground-truth task-trained network. D) Fixed-point structure inferred by JEDI.
  • ...and 12 more figures
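The embedding-noise test summarized in Figure 3 can be sketched as follows. Assumptions, none confirmed by the captions: noise is Gaussian and scaled per embedding dimension by the standard deviation across all of a model's embeddings; $R^2$ is pooled over time steps and units; and the reference is the model's own noise-free output rather than recorded data. The toy model (`simulate`, `W0`, `H`) and the helpers `r2_score` and `perturb_embeddings` are hypothetical stand-ins for a trained JEDI model and its evaluation code.

```python
import numpy as np

def r2_score(y_true, y_pred):
    """R^2 pooled over all time steps and units (variance-weighted)."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot

def perturb_embeddings(embeddings, noise_level, rng):
    """Add Gaussian noise scaled by each dimension's std across all embeddings."""
    scale = embeddings.std(axis=0)
    return embeddings + noise_level * scale * rng.normal(size=embeddings.shape)

# Hypothetical stand-in for a trained model: embedding -> recurrent weights -> trajectory.
rng = np.random.default_rng(1)
N, D, T = 16, 3, 40  # units, embedding dimension, time steps (assumed)
W0 = rng.normal(0.0, 1.5 / np.sqrt(N), (N, N))
H = rng.normal(0.0, 0.2 / np.sqrt(N), (N * N, D))
x0 = rng.normal(size=N)

def simulate(embedding):
    """Run an autonomous tanh RNN whose weights depend on the embedding."""
    W = W0 + (H @ embedding).reshape(N, N)
    x, states = x0, []
    for _ in range(T):
        x = np.tanh(W @ x)
        states.append(x)
    return np.stack(states)

embeddings = rng.normal(size=(20, D))          # one embedding per trial
reference = [simulate(e) for e in embeddings]  # noise-free outputs

results = {}
for level in [0.0, 0.5, 1.0, 2.0]:
    noisy = perturb_embeddings(embeddings, level, rng)
    scores = [r2_score(ref, simulate(e)) for ref, e in zip(reference, noisy)]
    results[level] = float(np.mean(scores))
    print(f"noise = {level:.1f} x std: mean R^2 = {results[level]:.3f}")
```

Scaling the noise by each model's own embedding spread keeps noise levels comparable across models whose embeddings live on different scales.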