
Real World Federated Learning with a Knowledge Distilled Transformer for Cardiac CT Imaging

Malte Tölle, Philipp Garthe, Clemens Scherer, Jan Moritz Seliger, Andreas Leha, Nina Krüger, Stefan Simm, Simon Martin, Sebastian Eble, Halvar Kelm, Moritz Bednorz, Florian André, Peter Bannas, Gerhard Diller, Norbert Frey, Stefan Groß, Anja Hennemuth, Lars Kaderali, Alexander Meyer, Eike Nagel, Stefan Orwat, Moritz Seiffert, Tim Friede, Tim Seidler, Sandy Engelhardt

TL;DR

This work conducts the largest federated cardiac CT analysis to date in a real-world setting across eight hospitals; its knowledge-distilled transformer outperforms UNet-based models in generalizability on downstream tasks.

Abstract

Federated learning is a well-established technique for utilizing decentralized data while preserving privacy. However, real-world applications often face challenges such as partially labeled datasets, where only a few locations have certain expert annotations, leaving large portions of unlabeled data unused. Leveraging these data could enhance the ability of transformer architectures in regimes with small and diversely annotated sets. We conduct the largest federated cardiac CT analysis to date (n=8,104) in a real-world setting across eight hospitals. Our two-step semi-supervised strategy distills knowledge from task-specific CNNs into a transformer. First, CNNs predict on unlabeled data per label type; then the transformer learns from these predictions with label-specific heads. This improves predictive accuracy, enables simultaneous learning of all partial labels across the federation, and yields a transformer that outperforms UNet-based models in generalizability on downstream tasks. Code and model weights are openly available to support future cardiac CT analysis.
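The two-step strategy above (one federated teacher per label type, then pseudo-labels for a multi-head student) can be sketched in plain Python. This is a toy illustration under stated assumptions, not the authors' code: model parameters are flat lists of floats, the aggregation rule is sample-size weighted averaging in the style of FedAvg (the abstract does not name an aggregation rule), and the function names `fedavg`, `aggregate_teachers`, and `pseudo_label` as well as the label keys are hypothetical.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Weights = List[float]  # assumption: a model is represented as a flat parameter vector


def fedavg(updates: Sequence[Tuple[Weights, int]]) -> Weights:
    """Average parameter vectors, weighting each site by its number of training samples.

    Assumed FedAvg-style rule; the source does not specify the aggregation.
    """
    total = sum(n for _, n in updates)
    if total == 0:
        raise ValueError("no training samples to aggregate")
    dim = len(updates[0][0])
    return [sum(w[i] * n for w, n in updates) / total for i in range(dim)]


def aggregate_teachers(
    site_updates: Dict[str, Dict[str, Tuple[Weights, int]]],
) -> Dict[str, Weights]:
    """Stage 1A sketch: one teacher per label type.

    Each teacher is aggregated only over the sites that hold that label,
    so a site without calcification labels never contributes to that teacher.
    """
    by_label: Dict[str, List[Tuple[Weights, int]]] = {}
    for labels in site_updates.values():
        for label, update in labels.items():
            by_label.setdefault(label, []).append(update)
    return {label: fedavg(ups) for label, ups in by_label.items()}


def pseudo_label(samples: Sequence[object],
                 teachers: Dict[str, Callable[[object], object]]) -> List[Dict[str, object]]:
    """Stage 1B sketch: every teacher predicts on every sample.

    The result holds one target per label type, i.e. one per student head.
    """
    return [{label: predict(x) for label, predict in teachers.items()} for x in samples]
```

Stage 2 would then train one shared backbone against every head's target at once and afterwards finetune only the heads on human annotations; that step depends on the deep learning framework and is left out of this sketch.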

Paper Structure (20 sections, 1 equation, 8 figures, 4 tables)

Figures (8)

  • Figure 1: Overview of the federated consortium and the federated knowledge distillation (KD) training pipeline. a) Federated learning procedure and b) our consortium of eight university hospitals in Germany. c) Not every label subset is present at all locations (Stage 1A), so one model (UNet) is trained for each subset in a federated manner across the locations that hold that label. d) The federated models are then used to make predictions on the unlabeled data samples (Stage 1B). e) The transformer-based model is trained on the teacher networks' predictions with three heads sharing the same backbone (Stage 2AB). Finally, only the heads are finetuned on the human-annotated data samples. Naming is consistent with Figure 7.
  • Figure 1: Demographics of patients and data properties across locations. Some data were not available at every location. Scanners from three manufacturers, eleven models in total, were included in the federated training. The acquisition protocols vary across locations in terms of exposure, exposure time, X-ray tube current, and contrast bolus volume. Manufacturer acronyms: P: Philips, S: Siemens, T: Toshiba.
  • Figure 2: Comparison of UNets and the transformer-based model (SWIN-UNETR) in boxplots for local, federated, and federated KD training for a) Hinge Points & Coronary Arteries (HPS & CAs), b) Membranous Septum (MS), and c) Calcification. Test results on training clients are shown in blue; results on independent test clients are shown in orange. The boxplots show the median, the 25th and 75th percentiles, and outliers. The locally trained models perform well on their own location's data but do not generalize to data from other locations. The transformer-based architecture performs worse than the UNet. Federated training improves generalization, but the UNet still performs and generalizes better. After federated KD and subsequent finetuning, the transformer-based model is on par with the UNet in detecting the hinge points, coronary ostia, and membranous septum, and outperforms it in segmenting the calcification. Because KD provides more training samples, the predictive performance of the SWIN-UNETR rises to match or exceed that of the UNet, whereas KD does not improve the UNet to a similar degree.
  • Figure 3: Qualitative results of the labels predicted by FedKD SWIN-UNETR. The predictions of our final distilled model were inspected by two experienced cardiologists, who verified that the points are placed within the anatomical variability present. RCC: right coronary cusp, LCC: left coronary cusp, NCC: non-coronary cusp, RCO: right coronary ostium, LCO: left coronary ostium, MS1: upper and MS2: lower point of the membranous septum, Myo: myocardium, LA: left atrium, LV: left ventricle, RA: right atrium, RV: right ventricle, PA: pulmonary artery.
  • Figure 4: Privacy-preserving inspection of labels. The overall distribution of landmarks should be similar across locations, because the geometric relations between the points are relatively homogeneous. a) Human-annotated and b) model-predicted hinge points; c) human-annotated and d) model-predicted membranous septum landmarks. In a) and b), the AA plane is defined by the three hinge points, the center point is registered, and the rotation angle is chosen to minimize the deviation from an optimal arrangement with 120° between the three points. In c) and d), the RCC and NCC hinge points are registered, and the location of the two membranous septum points relative to these two points is visualized. This allows the overall label quality to be inspected without disclosing any image information. In c), MS1 and MS2 are confused (downward-pointing arrow). The spread is larger for the human-annotated labels, which we attribute to slightly different annotation habits. RCC: right coronary cusp, LCC: left coronary cusp, NCC: non-coronary cusp, MS1: upper point of the membranous septum, MS2: lower point of the membranous septum.
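The hinge-point normalization described for Figure 4 a) and b) can be approximated as follows. This is a sketch under assumptions, not the published procedure: the registered center is taken to be the centroid of the three hinge points, the in-plane axis is anchored at RCC, and the "distance from an optimal 120° orientation" is modeled as the sum of 1 - cos of the angular deviations, whose minimizing rotation is the circular mean of the angular offsets. `normalize_hinge_points` is a hypothetical name. Only the resulting 2D coordinates would leave a site, never image data.

```python
import math
from typing import List, Tuple

Vec3 = Tuple[float, float, float]


def _sub(a: Vec3, b: Vec3) -> Vec3:
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def _dot(a: Vec3, b: Vec3) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def _cross(a: Vec3, b: Vec3) -> Vec3:
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def _unit(a: Vec3) -> Vec3:
    norm = math.sqrt(_dot(a, a))
    if norm == 0.0:
        raise ValueError("degenerate hinge-point geometry")
    return (a[0] / norm, a[1] / norm, a[2] / norm)


def normalize_hinge_points(rcc: Vec3, lcc: Vec3, ncc: Vec3) -> List[Tuple[float, float]]:
    """Map RCC, LCC, NCC to 2D plane coordinates, centred and rotated toward a 120° layout.

    Hypothetical helper; centroid registration and the cosine-based angular
    loss are assumptions, not the paper's exact definition.
    """
    pts = (rcc, lcc, ncc)
    c: Vec3 = (sum(p[0] for p in pts) / 3.0,
               sum(p[1] for p in pts) / 3.0,
               sum(p[2] for p in pts) / 3.0)          # assumed "center point"
    n = _unit(_cross(_sub(lcc, rcc), _sub(ncc, rcc)))  # normal of the plane through the points
    u = _unit(_sub(rcc, c))                            # in-plane axis pointing at RCC
    v = _cross(n, u)                                   # unit vector completing the in-plane basis
    coords = [(_dot(_sub(p, c), u), _dot(_sub(p, c), v)) for p in pts]

    # Ideal layout: the three points spaced 120° apart around the center.
    ideal = [0.0, 2.0 * math.pi / 3.0, 4.0 * math.pi / 3.0]
    angles = [math.atan2(y, x) for x, y in coords]
    # The rotation minimising sum(1 - cos(ideal_k - angle_k - theta)) is the
    # circular mean of the offsets ideal_k - angle_k.
    s = sum(math.sin(t - a) for t, a in zip(ideal, angles))
    k = sum(math.cos(t - a) for t, a in zip(ideal, angles))
    theta = math.atan2(s, k)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    return [(x * cos_t - y * sin_t, x * sin_t + y * cos_t) for x, y in coords]
```

Because the output does not change under any rotation or translation of the three points in the scanner frame, scatter plots pooled from several sites are directly comparable.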