BrainNPT: Pre-training of Transformer networks for brain network classification

Jinlong Hu, Yangmin Huang, Nan Wang, Shoubin Dong

TL;DR

BrainNPT introduces a Transformer-based architecture for brain functional network classification that uses a learnable <cls> token to summarize network structure. A pre-training framework built on a Replaced ROI Prediction (RRP) task and sliding-window data augmentation lets the model learn from unlabeled rs-fMRI data, yielding substantial downstream gains (e.g., an accuracy improvement of about 8.75% on ABIDE II). On both ABIDE II and REST-meta-MDD, pre-training with RRP consistently improves performance over non-pre-trained models and BrainNetTF baselines, while interpretation with layer-wise relevance propagation (LRP) identifies biologically plausible ROIs linked to autism and to default-mode network disruptions. The work demonstrates the practical value of Transformer-based pre-training in neuroscience, highlights the importance of pre-training scale, and offers interpretable insights into brain network organization.
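The TL;DR attributes part of the pre-training gain to sliding-window data augmentation. Below is a minimal NumPy sketch, assuming that each brain functional network is a Pearson correlation matrix over ROI time series and that augmentation means computing one such matrix per overlapping time window, so a single scan yields several training networks. The function name `sliding_window_networks` and the `window` and `stride` values are hypothetical; the paper's actual settings are not given here.

```python
import numpy as np


def sliding_window_networks(timeseries, window, stride):
    """Build one functional-connectivity matrix per sliding time window.

    Assumption (not from the paper): a "brain network" is the Pearson
    correlation matrix of the ROI signals inside the window.

    timeseries: array of shape (n_timepoints, n_rois), one BOLD signal per ROI.
    Returns an array of shape (n_windows, n_rois, n_rois).
    """
    n_timepoints, n_rois = timeseries.shape
    if window > n_timepoints:
        raise ValueError("window is longer than the scan")
    networks = []
    for start in range(0, n_timepoints - window + 1, stride):
        segment = timeseries[start:start + window]
        # np.corrcoef treats rows as variables, so transpose to (n_rois, window)
        networks.append(np.corrcoef(segment.T))
    return np.stack(networks)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    scan = rng.standard_normal((200, 10))  # synthetic 200-timepoint, 10-ROI scan
    nets = sliding_window_networks(scan, window=100, stride=20)
    print(nets.shape)
```

With a 200-timepoint scan, `window=100` and `stride=20` give (200 - 100) / 20 + 1 = 6 networks per subject; a smaller stride produces more samples at the cost of more overlap between them.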

Abstract

Deep learning methods have advanced quickly in brain imaging analysis over the past few years, but they are usually limited by the scarcity of labeled data. Pre-training models on unlabeled data has yielded promising improvements in feature learning in many domains, including natural language processing and computer vision. However, this technique is under-explored in brain network analysis. In this paper, we focus on pre-training methods with Transformer networks that leverage existing unlabeled data for brain functional network classification. First, we propose a Transformer-based neural network, named BrainNPT, for brain functional network classification. The proposed method uses a <cls> token as a classification embedding vector so that the Transformer model can effectively capture the representation of the brain network. Second, we propose a pre-training framework for the BrainNPT model that leverages unlabeled brain network data to learn the structural information of brain networks. Classification experiments showed that the BrainNPT model without pre-training matched the best-performing state-of-the-art models, and that the BrainNPT model with pre-training strongly outperformed them. Pre-training improved the accuracy of BrainNPT by 8.75% compared with the model without pre-training. We further compared pre-training strategies, analyzed the influence of the model's parameters, and interpreted the trained model.
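To make the <cls>-token idea concrete, here is a minimal NumPy forward pass with random, untrained weights. It assumes (hypothetically) that each ROI's row of the connectivity matrix is linearly embedded as one token, that a learnable <cls> vector is prepended, and that one single-head Transformer block with post-layer-norm is followed by a small MLP head that reads only the <cls> position. The names `init_params` and `forward`, the layer sizes, the omission of positional encodings, and the parameter-free layer norm are all illustrative choices, not BrainNPT's exact configuration.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def layer_norm(x, eps=1e-5):
    # Parameter-free layer norm (no learnable scale/shift in this sketch)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def init_params(n_rois, d_model, n_classes, rng):
    """Random (untrained) weights for a one-block, single-head encoder."""
    def w(*shape):
        return rng.standard_normal(shape) / np.sqrt(shape[0])
    return {
        "embed": w(n_rois, d_model),                 # ROI connectivity row -> token
        "cls": rng.standard_normal(d_model) * 0.02,  # learnable <cls> vector
        "wq": w(d_model, d_model),
        "wk": w(d_model, d_model),
        "wv": w(d_model, d_model),
        "wo": w(d_model, d_model),
        "w1": w(d_model, 4 * d_model),               # feed-forward expansion
        "w2": w(4 * d_model, d_model),
        "head1": w(d_model, d_model),                # two-layer MLP classifier
        "head2": w(d_model, n_classes),
    }


def forward(network, p):
    """network: (n_rois, n_rois) connectivity matrix; returns class logits."""
    tokens = network @ p["embed"]          # (n_rois, d_model): one token per ROI
    x = np.vstack([p["cls"], tokens])      # prepend <cls>: (n_rois + 1, d_model)
    # Self-attention sub-layer with residual connection
    q, k, v = x @ p["wq"], x @ p["wk"], x @ p["wv"]
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    x = layer_norm(x + (attn @ v) @ p["wo"])
    # Position-wise feed-forward sub-layer with residual connection
    h = np.maximum(0.0, x @ p["w1"])
    x = layer_norm(x + h @ p["w2"])
    # Only the <cls> position feeds the MLP head
    z = np.maximum(0.0, x[0] @ p["head1"])
    return z @ p["head2"]
```

Because every token attends to every other token, the <cls> output after the block is a learned summary of the whole network; in training, `softmax(forward(net, p))` would be compared with the diagnostic label via cross-entropy, typically in a deep-learning framework rather than plain NumPy.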

Paper Structure (28 sections, 4 equations, 8 figures, 9 tables)

Figures (8)

  • Figure 1: The architecture of BrainNPT. The architecture comprises the classification embedding vector <cls>, a Transformer block, and an MLP block. The Transformer block can be stacked into multiple layers.
  • Figure 2: The framework of pre-training and fine-tuning of BrainNPT. Left: pre-training with a randomly replaced ROI strategy. Right: fine-tuning, in which the parameters learned during pre-training are reused for downstream tasks.
  • Figure 3: The framework of pre-training with RRP.
  • Figure 4: The framework of pre-training with MRM.
  • Figure 5: Training and testing evaluation of the RRP model. The training loss stabilizes after about 45 rounds, and the test accuracy of predicting replaced ROIs with RRP is about 96%.
  • ...and 3 more figures
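Figures 2, 3, and 5 describe pre-training by randomly replacing ROIs and then predicting which ones were replaced. The NumPy sketch below covers only the data and scoring side of that task, under explicit assumptions that the source does not confirm: a replaced ROI's token (its connectivity row) is swapped for the same ROI's row from a different subject, the replacement ratio is a free parameter, and the model outputs one "replaced" probability per ROI, scored with binary cross-entropy. The names `replace_rois`, `rrp_loss`, `rrp_accuracy`, and `ratio` are hypothetical.

```python
import numpy as np


def replace_rois(networks, idx, ratio, rng):
    """Corrupt networks[idx] for a Replaced-ROI-Prediction-style task.

    Assumption: a replaced ROI's row comes from the same ROI in another subject.
    networks: (n_subjects, n_rois, n_rois) stack of connectivity matrices.
    Returns the corrupted matrix and a 0/1 label per ROI (1 = replaced).
    """
    n_subjects, n_rois, _ = networks.shape
    if n_subjects < 2:
        raise ValueError("need at least two subjects to draw donor rows")
    corrupted = networks[idx].copy()
    labels = np.zeros(n_rois, dtype=int)
    n_replace = max(1, int(round(ratio * n_rois)))
    rois = rng.choice(n_rois, size=n_replace, replace=False)
    others = [s for s in range(n_subjects) if s != idx]
    donors = rng.choice(others, size=n_replace)  # donor subject for each ROI
    for roi, donor in zip(rois, donors):
        corrupted[roi] = networks[donor, roi]     # swap in another subject's row
        labels[roi] = 1
    return corrupted, labels


def rrp_loss(probs, labels, eps=1e-7):
    """Mean per-ROI binary cross-entropy between predicted and true labels."""
    p = np.clip(np.asarray(probs, dtype=float), eps, 1 - eps)
    return float(-np.mean(labels * np.log(p) + (1 - labels) * np.log(1 - p)))


def rrp_accuracy(probs, labels, threshold=0.5):
    """Fraction of ROIs whose replaced/original status is predicted correctly."""
    preds = (np.asarray(probs) >= threshold).astype(int)
    return float(np.mean(preds == labels))
```

An accuracy computed this way over held-out networks corresponds to the kind of figure Figure 5 reports (about 96%); the sketch itself trains nothing, and after pre-training the encoder weights would be reused for fine-tuning as in Figure 2.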