Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking

Zijian Dong, Ruilin Li, Yilei Wu, Thuan Tinh Nguyen, Joanna Su Xian Chong, Fang Ji, Nathanael Ren Jie Tong, Christopher Li Hsian Chen, Juan Helen Zhou

TL;DR

Brain-JEPA achieves state-of-the-art performance in demographic prediction, disease diagnosis/prognosis, and trait prediction through fine-tuning, and it demonstrates superior generalizability across different ethnic groups, significantly surpassing the previous large model for brain activity.

Abstract

We introduce Brain-JEPA, a brain dynamics foundation model built on the Joint-Embedding Predictive Architecture (JEPA). This pioneering model achieves state-of-the-art performance in demographic prediction, disease diagnosis/prognosis, and trait prediction through fine-tuning. It also excels in off-the-shelf evaluations (e.g., linear probing) and demonstrates superior generalizability across different ethnic groups, significantly surpassing the previous large model for brain activity. Brain-JEPA incorporates two new techniques: Brain Gradient Positioning and Spatiotemporal Masking. Brain Gradient Positioning introduces a functional coordinate system for brain functional parcellation, improving the positional encoding of different Regions of Interest (ROIs). Spatiotemporal Masking, tailored to the unique characteristics of fMRI data, addresses the challenge of heterogeneous time-series patches. These techniques improve model performance and advance our understanding of the neural circuits underlying cognition. Overall, Brain-JEPA paves the way for addressing two pivotal questions at the AI-neuroscience interface, namely how to build a brain functional coordinate system and how to mask brain activity, and it potentially sets a new paradigm in brain activity analysis through downstream adaptation.
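The abstract does not spell out how brain gradients become positional embeddings, so the following is only a minimal NumPy sketch under stated assumptions. It assumes (i) gradient coordinates come from a simplified spectral embedding of an ROI-by-ROI functional connectivity matrix, used here as a stand-in for the diffusion-map embedding commonly used to compute cortical gradients; (ii) a hypothetical projection matrix `proj` (learnable in a real model) maps those coordinates to half of the embedding width; and (iii) the other half is the sine/cosine temporal encoding mentioned in the Figure 1 caption. The function names, the half/half split, and the synthetic data are illustrative, not the authors' implementation.

```python
import numpy as np


def gradient_coordinates(fc: np.ndarray, n_gradients: int = 3) -> np.ndarray:
    """Place each ROI in a low-dimensional 'gradient' space derived from connectivity.

    Simplified spectral embedding (a stand-in for diffusion-map gradients):
    cosine-similarity affinity between FC rows, normalized into a Markov matrix,
    keeping the leading non-trivial eigenvectors.
    """
    rows = fc / np.linalg.norm(fc, axis=1, keepdims=True)
    affinity = np.clip(rows @ rows.T, 0.0, None)  # non-negative ROI-ROI similarity
    degree = affinity.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(degree)
    # D^-1/2 A D^-1/2 is symmetric, so eigh applies; it shares eigenvalues with D^-1 A
    sym = affinity * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    eigvals, eigvecs = np.linalg.eigh(sym)          # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1]               # largest first
    psi = eigvecs[:, order] * d_inv_sqrt[:, None]   # right eigenvectors of D^-1 A
    return psi[:, 1 : n_gradients + 1]              # drop the constant (trivial) vector


def sincos_1d(positions: np.ndarray, dim: int) -> np.ndarray:
    """Transformer-style sine/cosine encoding of 1-D positions (here: time patches)."""
    assert dim % 2 == 0
    freqs = 1.0 / (10000.0 ** (np.arange(dim // 2) / (dim // 2)))
    angles = positions[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


def positional_embedding(grad: np.ndarray, n_time: int, dim: int, proj: np.ndarray) -> np.ndarray:
    """Embedding for every (ROI, time-patch) slot: gradient half + temporal half."""
    n_roi, half = grad.shape[0], dim // 2
    roi_part = grad @ proj                                        # (n_roi, half); proj would be learned
    time_part = sincos_1d(np.arange(n_time, dtype=float), half)   # (n_time, half)
    roi_b = np.broadcast_to(roi_part[:, None, :], (n_roi, n_time, half))
    time_b = np.broadcast_to(time_part[None, :, :], (n_roi, n_time, half))
    return np.concatenate([roi_b, time_b], axis=-1)               # (n_roi, n_time, dim)


rng = np.random.default_rng(0)
ts = rng.standard_normal((200, 50))        # hypothetical: 200 time points x 50 ROIs
fc = np.corrcoef(ts.T)                     # 50 x 50 functional connectivity
grad = gradient_coordinates(fc, n_gradients=3)
proj = 0.1 * rng.standard_normal((3, 8))   # stand-in for a learnable linear layer
pe = positional_embedding(grad, n_time=10, dim=16, proj=proj)
print(grad.shape, pe.shape)                # (50, 3) (50, 10, 16)
```

The design point the sketch tries to convey is that ROIs close together in the functional gradient space receive similar spatial embeddings, whereas an arbitrary ROI index ordering carries no such geometry.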

Paper Structure

This paper contains 29 sections, 6 equations, 7 figures, and 13 tables.

Figures (7)

  • Figure 1: Brain-JEPA. With a Vision Transformer (ViT) as the observation encoder $f_{\theta}$, Brain-JEPA uses a single observation block to predict the representations of target blocks. (1) The input fMRI data is first segmented into patches. (2) Through Spatiotemporal Masking, the input data, excluding the observation block, is divided into three distinct regions: Cross-ROI ($\alpha$), Cross-Time ($\beta$), and Double-Cross ($\gamma$). Target blocks are sampled separately from each region. (3) A narrower ViT, serving as the predictor $g_{\phi}$, takes the output $\boldsymbol{s}_x$ of $f_{\theta}$ and predicts the representations $\hat{\boldsymbol{s}}^{r}_y$ of a target block, conditioned on positional embeddings (brain gradient positioning for ROI locations and sine and cosine functions for temporal positions). (4) These predicted representations are aligned with the representations $\boldsymbol{s}^{r}_y$ from the target encoder $f_{\overline{\theta}}$, whose parameters are updated as an Exponential Moving Average (EMA) of the observation encoder's parameters.
  • Figure 2: Brain gradient positioning. Brain cortical regions are positioned along the top three gradient axes and colored according to their positions. These colors are then projected back onto the brain surface.
  • Figure 3: Performance scaling with model size.
  • Figure 4: Fine-tuning vs. linear probing.
  • Figure 5: Comparisons of spatial positional embeddings (for the first task, the left $y$-axis shows Pearson's correlation; for the last two tasks, the right $y$-axis shows accuracy).
  • ...and 2 more figures
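Steps (2) and (4) of the Figure 1 caption can be illustrated with a minimal NumPy sketch. It assumes the patched fMRI input is a grid of ROIs by time patches and that the observation block is given as a set of ROI indices and a set of time-patch indices; the function names, the random-subset target sampling rule, and the momentum value 0.996 are hypothetical choices for illustration, not the authors' settings. The encoders and the predictor are omitted.

```python
import numpy as np


def spatiotemporal_regions(n_roi: int, n_time: int, obs_rois, obs_times) -> dict:
    """Partition the (ROI x time-patch) grid around an observation block.

    Returns boolean masks for the observation block and the three target regions
    named in Figure 1: Cross-ROI (alpha), Cross-Time (beta), Double-Cross (gamma).
    """
    roi_in = np.zeros(n_roi, dtype=bool)
    roi_in[obs_rois] = True
    time_in = np.zeros(n_time, dtype=bool)
    time_in[obs_times] = True
    r, t = roi_in[:, None], time_in[None, :]   # broadcast to (n_roi, n_time)
    return {
        "observation": r & t,
        "cross_roi": ~r & t,      # alpha: other ROIs, same time patches
        "cross_time": r & ~t,     # beta: same ROIs, other time patches
        "double_cross": ~r & ~t,  # gamma: other ROIs and other time patches
    }


def sample_target(region: np.ndarray, rng: np.random.Generator, n_patches: int) -> np.ndarray:
    """Draw flat patch indices for one target from a single region (hypothetical rule)."""
    candidates = np.flatnonzero(region)
    return rng.choice(candidates, size=min(n_patches, candidates.size), replace=False)


def ema_update(target: dict, online: dict, momentum: float = 0.996) -> dict:
    """theta_bar <- m * theta_bar + (1 - m) * theta, parameter by parameter."""
    return {k: momentum * target[k] + (1.0 - momentum) * online[k] for k in target}


rng = np.random.default_rng(0)
masks = spatiotemporal_regions(n_roi=8, n_time=6, obs_rois=[0, 1, 2], obs_times=[0, 1])
for name, m in masks.items():
    print(name, int(m.sum()))
# prints, one per line: observation 6, cross_roi 10, cross_time 12, double_cross 20
targets = {name: sample_target(masks[name], rng, 4) for name in ("cross_roi", "cross_time", "double_cross")}
target_params = ema_update({"w": np.zeros(3)}, {"w": np.ones(3)})
```

The four masks partition the grid, so every non-observed patch belongs to exactly one target region; sampling targets per region forces the predictor to infer across ROIs, across time, and across both at once. The EMA step only moves the target encoder slowly toward the observation encoder, which is why it needs no gradients of its own.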