Generalizable Representation Learning for fMRI-based Neurological Disorder Identification

Wenhui Cui, Haleh Akrami, Anand A. Joshi, Richard M. Leahy

TL;DR

This work tackles the challenge of limited and heterogeneous clinical fMRI data for neurological disorder identification. It introduces MeTSK, a representation learning framework that combines self-supervised learning on large control datasets with bi-level meta-learning to transfer representations to scarce clinical domains. Through linear probing, MeTSK demonstrates superior generalization across unseen clinical datasets and multiple disorders, outperforming baseline transfer methods and recent foundation-model approaches. The study provides code and highlights practical impact for robust, data-efficient neurodiagnostics in real-world clinical settings.

Abstract

Despite the impressive advances achieved using deep learning for functional brain activity analysis, the heterogeneity of functional patterns and the scarcity of imaging data still pose challenges in tasks such as identifying neurological disorders. For functional Magnetic Resonance Imaging (fMRI), while data may be abundantly available from healthy controls, clinical data is often scarce, especially for rare diseases, limiting the ability of models to identify clinically-relevant features. We overcome this limitation by introducing a novel representation learning strategy integrating meta-learning with self-supervised learning to improve the generalization from normal to clinical features. This approach enables generalization to challenging clinical tasks featuring scarce training data. We achieve this by leveraging self-supervised learning on the control dataset to focus on inherent features that are not limited to a particular supervised task and incorporating meta-learning to improve the generalization across domains. To explore the generalizability of the learned representations to unseen clinical applications, we apply the model to four distinct clinical datasets featuring scarce and heterogeneous data for neurological disorder classification. Results demonstrate the superiority of our representation learning strategy on diverse clinically-relevant tasks. Code is publicly available at https://github.com/wenhui0206/MeTSK/tree/main

Paper Structure

This paper contains 34 sections, 8 equations, 6 figures, 10 tables.

Figures (6)

  • Figure 1: An illustration of MeTSK for generalizable representation learning. Training MeTSK involves two optimization loops. The inner loop updates only the target head, while the outer loop updates the source head and the feature extractor. The representation learning pipeline first trains MeTSK on the source and target data, and then evaluates the learned representations on unseen clinical datasets using neurological disorder classification tasks.
  • Figure 2: An illustration of the ST-GCN model architecture. Spatial graph convolution is first applied to the spatial graph at each time point. Then 1D temporal convolutions are performed along the resulting features on each node. Multiple sub-sequences are randomly sampled from the whole time series as input graphs.
  • Figure 3: Feature importance map of PTE features generated from MeTSK shown as color-coded ROIs overlaid on the AAL atlas. The numbers represent the absolute value of coefficients from the trained SVM.
  • Figure 4: A comparison of the linear probing performance between the MeTSK model trained with a single target site (MeTSK-ADHD-Peking, results from Table \ref{ad}) and the MeTSK model trained using all the target sites from the ADHD-200 dataset (MeTSK-ADHD-All). Also shown, as a baseline, are results from linear classifiers trained directly on connectivity features. The height of each bar indicates the average AUC computed from 5-fold cross-validation, and the error bars denote the standard deviations. MeTSK-ADHD-All achieved performance similar to MeTSK-ADHD-Peking on all four downstream datasets.
  • Figure 5: Loss evolution during MeTSK training over 90 epochs. The first 20 epochs are a warm-up phase in which only the self-supervised task is trained. All loss values are averaged over batches and recorded every 2 epochs. Left: Classification loss on the meta-training set after the inner-loop updates in Step 2, recorded at the last update step of each training iteration. This loss remains within a stable range and gradually decreases, indicating that the target head adapts effectively to the features trained in Step 3 without negative interference. Middle: Self-supervised loss from Step 3, showing a steady decline. Right: Meta-validation classification loss in Step 3, computed with the frozen target head, which shows a generally decreasing trend with some fluctuations.
  • ...and 1 more figure
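
The two-loop training scheme described in the captions of Figures 1 and 5 can be illustrated with a small sketch. This is not the authors' implementation: the actual model is an ST-GCN and the paper's self-supervised objective is not reproduced here. The sketch assumes a linear feature extractor `W`, a logistic-regression target head `v`, and a linear reconstruction task standing in for the self-supervised source task with source head `U`. The function names, the weight `lam`, the learning rates, and the first-order outer update (no gradient through the inner loop) are all illustrative assumptions.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def target_loss_and_grads(W, v, X, y):
    """Binary cross-entropy of the target head v on features X @ W."""
    z = X @ W                      # features from the shared extractor
    p = sigmoid(z @ v)
    eps = 1e-12                    # avoids log(0)
    loss = -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))
    g = (p - y) / len(y)           # d(loss)/d(logits)
    return loss, X.T @ np.outer(g, v), z.T @ g   # loss, grad_W, grad_v

def source_loss_and_grads(W, U, X):
    """Stand-in self-supervised task (assumption): reconstruct X from X @ W."""
    z = X @ W
    diff = z @ U - X
    loss = 0.5 * np.mean(np.sum(diff ** 2, axis=1))
    r = diff / len(X)
    return loss, X.T @ (r @ U.T), z.T @ r        # loss, grad_W, grad_U

def inner_adapt(W, v, X, y, steps=5, lr=0.1):
    """Inner loop: only the target head v is updated; W is left untouched."""
    for _ in range(steps):
        _, _, grad_v = target_loss_and_grads(W, v, X, y)
        v = v - lr * grad_v
    return v

def train(Xs, Xt_tr, yt_tr, Xt_val, yt_val, k=4, epochs=10,
          lr_out=0.01, lam=1.0, seed=0):
    rng = np.random.default_rng(seed)
    d = Xs.shape[1]
    W = rng.normal(scale=0.1, size=(d, k))   # shared feature extractor
    U = rng.normal(scale=0.1, size=(k, d))   # source (self-supervised) head
    v = np.zeros(k)                          # target (classification) head
    history = []
    for _ in range(epochs):
        # Inner loop on the target meta-training split.
        v = inner_adapt(W, v, Xt_tr, yt_tr)
        # Outer loop: self-supervised source loss plus the meta-validation
        # loss computed with the target head held fixed.
        loss_s, gW_s, gU = source_loss_and_grads(W, U, Xs)
        loss_v, gW_t, _ = target_loss_and_grads(W, v, Xt_val, yt_val)
        W = W - lr_out * (gW_s + lam * gW_t)
        U = U - lr_out * gU
        history.append((loss_s, loss_v))
    return W, U, v, history

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Xs = rng.normal(size=(64, 6))                       # synthetic "control" data
    Xt = rng.normal(size=(32, 6))                       # synthetic "clinical" data
    yt = (Xt[:, 0] > 0).astype(float)
    W, U, v, history = train(Xs, Xt[:16], yt[:16], Xt[16:], yt[16:])
    print("final (source loss, meta-validation loss):", history[-1])
```

After training, the extractor `W` would be frozen and a separate linear classifier fitted on `X @ W` for each unseen dataset, which corresponds to the linear probing evaluation mentioned in the TL;DR and in Figure 4.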