FUNAvg: Federated Uncertainty Weighted Averaging for Datasets with Diverse Labels

Malte Tölle, Fernando Navarro, Sebastian Eble, Ivo Wolf, Bjoern Menze, Sandy Engelhardt

TL;DR

The paper tackles learning from partially annotated medical imaging datasets that must remain private. It federates a shared backbone while giving each site its own segmentation head. It introduces FUNAvg, which performs uncertainty-weighted averaging of the head predictions. The uncertainty, derived from MC dropout, reveals structures that are not annotated at a site and is used to improve predictions for labels that are underrepresented across sites. The approach achieves Dice scores comparable to dataset-specific models, outperforms centralized baselines when uncertainty is incorporated, and improves calibration. The method enables effective multi-dataset segmentation under privacy constraints and heterogeneous annotation protocols, offering practical benefits for medical image analysis and potential applicability to other domains.
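The MC dropout uncertainty mentioned above can be illustrated with a minimal NumPy sketch. It assumes that one segmentation head has already been run T times with dropout kept active, producing T softmax maps, and it summarizes them into a mean prediction and a per-pixel uncertainty. Using the normalized predictive entropy as the uncertainty measure is an assumption made for illustration, and the function name `mc_dropout_uncertainty` is hypothetical, not taken from the FUNAvg repository.

```python
from __future__ import annotations

import numpy as np


def mc_dropout_uncertainty(prob_samples: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Summarize T stochastic forward passes of one segmentation head.

    prob_samples: array of shape (T, C, H, W) holding softmax outputs from
    T passes with dropout kept active at inference time (MC dropout).
    Returns the mean softmax map (C, H, W) and a per-pixel uncertainty map
    (H, W) in [0, 1], here the predictive entropy divided by log(C)
    (an assumed choice of uncertainty measure).
    """
    mean_prob = prob_samples.mean(axis=0)
    eps = 1e-12  # avoids log(0) for one-hot probabilities
    entropy = -(mean_prob * np.log(mean_prob + eps)).sum(axis=0)
    num_classes = mean_prob.shape[0]
    # Normalizing by the maximum entropy log(C) maps the result to [0, 1].
    uncertainty = np.clip(entropy / np.log(num_classes), 0.0, 1.0)
    return mean_prob, uncertainty
```

Passes that agree give an uncertainty near 0; passes that disagree completely on a two-class pixel give an uncertainty near 1.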

Abstract

Federated learning is a popular paradigm for training a joint model in a distributed, privacy-preserving environment. However, partial annotations pose an obstacle, because the categories of labels are heterogeneous across clients. We propose to learn a joint backbone in a federated manner, while each site receives its own multi-label segmentation head. Using Bayesian techniques, we observe that the different segmentation heads, although trained only on the individual client's labels, also learn information about the other labels not present at the respective site. This information is encoded in their predictive uncertainty. To obtain a final prediction, we leverage this uncertainty and perform a weighted averaging of the ensemble of distributed segmentation heads, which allows us to segment "locally unknown" structures. With our method, which we refer to as FUNAvg, we are on average even on par with models trained and tested on the same dataset. The code is publicly available at https://github.com/Cardio-AI/FUNAvg.
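The split between a federated backbone and site-specific heads can be sketched as follows. This is a minimal sketch under stated assumptions: client parameters are plain dicts of NumPy arrays, backbone entries share the hypothetical key prefix `backbone.`, and the server uses FedAvg-style weighting by local sample count (the exact weighting used by the authors is not stated here). Only backbone entries are averaged and sent back; each site keeps its own head, whose output size may differ between sites.

```python
from __future__ import annotations

import numpy as np


def average_backbones(client_params: list[dict[str, np.ndarray]],
                      num_samples: list[int],
                      prefix: str = "backbone.") -> dict[str, np.ndarray]:
    """Server step: FedAvg over backbone parameters only.

    Entries whose key does not start with `prefix` (the site-specific
    segmentation heads) are skipped, so heads are never aggregated.
    """
    total = float(sum(num_samples))
    keys = [k for k in client_params[0] if k.startswith(prefix)]
    # Weighted mean of each backbone tensor, weights = share of local samples.
    return {
        k: sum(n / total * params[k] for params, n in zip(client_params, num_samples))
        for k in keys
    }


def load_global_backbone(local_params: dict[str, np.ndarray],
                         global_backbone: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Client step: overwrite the local backbone, keep the local head."""
    updated = dict(local_params)
    updated.update(global_backbone)
    return updated
```

One communication round would then consist of local training at every site, one call to `average_backbones` on the server, and one call to `load_global_backbone` per site.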

Paper Structure (12 sections, 3 equations, 5 figures, 3 tables)

Figures (5)

  • Figure 1: Our proposed training and inference scheme. During training, each site optimizes its own segmentation head, whose outputs correspond to the labels present at that site. On the central server, only the backbone is averaged in a federated fashion. During inference, all segmentation heads are gathered and their softmax probabilities are averaged, with each label weighted by the number of sites at which it is present. Utilizing the predictive uncertainty of the classifiers with FUNAvg improves the predictions, which is especially beneficial for labels that are underrepresented across the federated sites.
  • Figure 2: Proposed federated uncertainty weighted averaging for (a) one pixel and (b) an entire image. (a) Averaging the softmax probabilities of sites 1-3 would yield the final prediction "Background" (Bg). By reweighting the background probability with the uncertainty (U), we obtain the correct label "Spleen" in this example. (b) After averaging the logits, the final prediction $\hat{y}$ is fragmentary, especially in the area of the lung in the example above, while the lung is clearly visible in the uncertainty estimate $\hat{u}$. We therefore multiply the background channel by one minus the uncertainty, $\hat{y}_b \times (1-\hat{u})$, to obtain $\hat{y}_u$.
  • Figure 3: Datasets used for training and testing and their respective label distributions. The datasets differ both in the number of training samples and in the number of annotated labels. We used the following open datasets: Liver Tumor Segmentation (LiTS) bilic2023lits, Beyond the Cranial Vault (BCV, Cervix and Abdomen) landman2015bcv, Combined Healthy Abdominal Organ (CHAOS) kavur2019chaos, Learn2Reg zhoubing2016l2r, AbdomenCT-1k ma2022AbCT1k, Abdominal Multi-Organ Benchmark (AMOS) yuanfeng2022amos, and TotalSegmentator wasserthal2022totalsegmentator, as well as two in-house datasets termed VisceralGC and SC.
  • Figure 4: Calibration of the different methods in terms of Expected Calibration Error (a), and performance gain of FUNAvg compared with vanilla averaging of the logits from the federatedly trained segmentation heads (b).
  • Figure 6: One representative slice from each dataset used in the federation, with the corresponding labels. In addition to the different structures annotated in each dataset, the field of view, and with it the intensity distribution, varies across datasets.
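The inference scheme described for Figures 1 and 2 can be sketched in NumPy. Assumptions: each head outputs a softmax map whose channel 0 is background and whose remaining channels map to hypothetical global label ids; each label is averaged over the sites where it is annotated (our reading of "weighted by the number of sites"); probabilities are averaged as in Figure 1; and the uncertainty map $\hat{u}$ is supplied by the caller, for example as the mean of the per-head MC dropout uncertainties, which is also an assumption. `fuse_heads` is a hypothetical name, not the repository's API.

```python
from __future__ import annotations

import numpy as np


def fuse_heads(head_probs: list[np.ndarray], head_labels: list[list[int]],
               num_labels: int, uncertainty: np.ndarray | None = None) -> np.ndarray:
    """Fuse per-site softmax maps into one label map.

    head_probs[i]: shape (1 + len(head_labels[i]), H, W), channel 0 = background.
    head_labels[i]: global label ids (1..num_labels) predicted by head i.
    uncertainty: optional (H, W) map in [0, 1]; if given, the background
    channel is scaled by (1 - uncertainty) before the argmax.
    Returns an (H, W) integer map with 0 = background.
    """
    h, w = head_probs[0].shape[1:]
    summed = np.zeros((num_labels + 1, h, w))
    counts = np.zeros(num_labels + 1)
    for probs, labels in zip(head_probs, head_labels):
        channels = [0] + list(labels)  # background is predicted by every head
        summed[channels] += probs
        counts[channels] += 1
    # Average each label over the sites where it is annotated.
    fused = summed / np.maximum(counts, 1)[:, None, None]
    if uncertainty is not None:
        fused[0] *= 1.0 - uncertainty  # FUNAvg-style background reweighting
    return fused.argmax(axis=0)
```

For example, with three sites where site 1 annotates only the spleen and sites 2 and 3 only the liver, a pixel with site-1 probabilities (0.4, 0.6) and background probabilities 0.9 and 0.8 at sites 2 and 3 is labeled background by plain averaging (0.7 versus 0.6 for spleen), but spleen once the background is scaled by (1 - 0.5).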
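The Expected Calibration Error reported in Figure 4 is commonly computed with equal-width confidence bins as ECE = sum over m of (|B_m| / n) * |acc(B_m) - conf(B_m)|. The sketch below assumes per-pixel confidences (the maximum softmax probability), a per-pixel correctness indicator (prediction equals ground truth), and 10 bins; the exact binning used in the paper is not given here.

```python
import numpy as np


def expected_calibration_error(confidence, correct, num_bins: int = 10) -> float:
    """Binned ECE = sum_m |B_m|/n * |acc(B_m) - conf(B_m)|.

    confidence: per-pixel confidence in [0, 1], e.g. the max softmax probability.
    correct: per-pixel 1 if the predicted label matches the ground truth, else 0.
    """
    confidence = np.asarray(confidence, dtype=float).ravel()
    correct = np.asarray(correct, dtype=float).ravel()
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    # Bin m covers (edges[m], edges[m + 1]]; a confidence of exactly 0 lands in bin 0.
    bin_ids = np.digitize(confidence, edges[1:-1], right=True)
    ece = 0.0
    for m in range(num_bins):
        in_bin = bin_ids == m
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidence[in_bin].mean())
            ece += in_bin.mean() * gap  # in_bin.mean() equals |B_m| / n
    return float(ece)
```

A perfectly calibrated model (for instance, 90% of pixels correct among pixels predicted with confidence 0.9) yields an ECE of 0, while overconfident predictions increase it.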