Federated Fine-tuning of SAM-Med3D for MRI-based Dementia Classification

Kaouther Mouheb, Marawan Elbatel, Janne Papma, Geert Jan Biessels, Jurgen Claassen, Huub Middelkoop, Barbara van Munster, Wiesje van der Flier, Inez Ramakers, Stefan Klein, Esther E. Bron

TL;DR

This work addresses how to effectively fine-tune large 3D foundation models in a federated setting for MRI-based dementia classification, evaluating the impact of classification head design, fine-tuning strategy, and aggregation method on diagnostic performance and efficiency. Using SAM-Med3D as the backbone across a large, heterogeneous multi-cohort dataset, the study demonstrates that convolutional classification heads substantially improve accuracy, freezing the encoder often matches full fine-tuning performance while reducing cost, and advanced aggregation (FedCE, Rate-My-LoRA) can approach or beat centralized fine-tuning by mitigating client heterogeneity. The authors provide an open-source framework for federated 3D FM evaluation and offer actionable guidance: prefer compact convolutional heads for efficient communication, consider encoder freezing in FL, and employ advanced aggregation to boost cross-site performance, especially for data-rich cohorts. These findings support practical deployment of federated FMs in decentralized clinical settings and point to future directions in FL theory and 3D FM development with privacy-preserving medical imaging.

Abstract

While foundation models (FMs) offer strong potential for AI-based dementia diagnosis, their integration into federated learning (FL) systems remains underexplored. In this benchmarking study, we systematically evaluate the impact of key design choices: classification head architecture, fine-tuning strategy, and aggregation method, on the performance and efficiency of federated FM tuning using brain MRI data. Using a large multi-cohort dataset, we find that the architecture of the classification head substantially influences performance, freezing the FM encoder achieves comparable results to full fine-tuning, and advanced aggregation methods outperform standard federated averaging. Our results offer practical insights for deploying FMs in decentralized clinical settings and highlight trade-offs that should guide future method development.

Paper Structure

This paper contains 15 sections, 1 equation, 5 figures, 3 tables.

Figures (5)

  • Figure 1: The framework explores three design choices: (1) Classification Head Architecture: linear, small CNN adapter (CONV S), and large CNN adapter (CONV L); (2) Fine-tuning Method: full model tuning, classifier-only (linear probing), and LoRA with selective attention block adaptation; (3) Aggregation Strategy: SimpleAvg, FedAvg, and the advanced methods FedCE and Rate-My-LoRA.
  • Figure 2: (a) AUC per classification head architecture for each dataset, with ResNet18 and NCC included as references. (b) AUC per fine-tuning method. Error bars show the 95% CI obtained by bootstrapping on the test set.
  • Figure 3: Test AUC score per client with different federated aggregation methods. Error bars show the 95% CI (test set bootstrapping).
  • Figure 4: Intensity profiles per client after normalization, obtained using Kernel Density Estimation over all voxels within the brain mask from all scans.
  • Figure 5: Aggregation weights assigned to each client per round by (a) FedCE and (b) Rate-My-LoRA. FedCE produces stable weight trajectories because its running aggregation incorporates past weights into current updates, with NACC and ADNI consistently contributing the most. The weights of smaller clients (e.g., PND) gradually increase, suggesting that FedCE progressively integrates their contributions as training evolves. In contrast, Rate-My-LoRA shows highly variable weighting across rounds, with several clients intermittently receiving negligible weights. This reflects Rate-My-LoRA’s reliance on per-round validation performance, which can lead to unstable client weighting and potential underutilization of data from smaller sites.
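
The two baseline aggregation strategies named in Figure 1 can be illustrated with a minimal sketch. It assumes that SimpleAvg denotes an unweighted mean of client parameters and that FedAvg weights each client by its number of training samples, as in the standard formulation. The flat `Params` format and the function names `aggregate`, `simple_avg`, and `fed_avg` are hypothetical and chosen only for illustration; this is not the authors' implementation, and it does not cover FedCE or Rate-My-LoRA.

```python
from typing import Dict, List

# Hypothetical format: parameter name -> flattened list of values.
Params = Dict[str, List[float]]


def aggregate(client_params: List[Params], weights: List[float]) -> Params:
    """Weighted average of client parameters; weights are normalized to sum to 1."""
    if not client_params or len(client_params) != len(weights):
        raise ValueError("need one weight per client and at least one client")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    norm = [w / total for w in weights]

    merged: Params = {}
    for name, reference in client_params[0].items():
        merged[name] = [
            sum(w * params[name][i] for w, params in zip(norm, client_params))
            for i in range(len(reference))
        ]
    return merged


def simple_avg(client_params: List[Params]) -> Params:
    """Assumed SimpleAvg: every client contributes equally."""
    return aggregate(client_params, [1.0] * len(client_params))


def fed_avg(client_params: List[Params], num_samples: List[int]) -> Params:
    """FedAvg: clients are weighted by their local sample counts."""
    return aggregate(client_params, [float(n) for n in num_samples])


# Toy example: two clients, the second with three times as much data.
clients = [{"head.w": [0.0, 2.0]}, {"head.w": [4.0, 6.0]}]
uniform = simple_avg(clients)
weighted = fed_avg(clients, num_samples=[100, 300])
print(uniform["head.w"])   # [2.0, 4.0]
print(weighted["head.w"])  # [3.0, 5.0]
```

In the toy example the sample-weighted result is pulled toward the larger client, the same effect that lets data-rich cohorts dominate sample-weighted aggregation. Under that reading, the advanced methods in Figure 5 amount to replacing the fixed `weights` vector with one that is re-estimated each round.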