Distributionally Robust Alignment for Medical Federated Vision-Language Pre-training Under Data Heterogeneity

Zitao Shuai, Chenwei Wu, Zhengxu Tang, Liyue Shen

TL;DR

This paper proposes Federated Distributionally Robust Alignment (FedDRA), a framework for federated VLP that achieves robust vision-language alignment under heterogeneous conditions and effectively enhances medical federated VLP under data heterogeneity.

Abstract

Vision-language pre-training (VLP) has emerged as an effective scheme for multimodal representation learning, but its reliance on large-scale multimodal data poses significant challenges for medical applications. Federated learning (FL) offers a promising way to scale up the dataset for medical VLP while preserving data privacy. However, we observe that client data heterogeneity in real-world scenarios can cause models to learn biased cross-modal alignment during local pre-training. This limits the transferability of the representation model learned via FL on downstream tasks. To address this challenge, we propose Federated Distributionally Robust Alignment (FedDRA), a framework for federated VLP that achieves robust vision-language alignment under heterogeneous conditions. Based on the client datasets, we construct a distribution family that encompasses potential test-time domains, and apply a distributionally robust framework to optimize the pre-trained model's performance across this distribution space. This approach bridges the gap between pre-training samples and downstream applications. To avoid overfitting to client-specific information, we use anchor representations from the global model to guide local training, and adopt a two-stage approach that first tunes the deeper layers before updating the entire network. Extensive experiments on real-world datasets demonstrate FedDRA's effectiveness in enhancing medical federated VLP under data heterogeneity. Our method also adapts well to various medical pre-training methods.
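The anchor-guided local training described in the abstract can be illustrated with a small NumPy sketch. It assumes a CLIP-style symmetric InfoNCE loss for vision-language alignment and a squared-distance penalty that pulls local embeddings toward the frozen global-model anchors $z^*_x$ and $z^*_y$. The function names, the temperature `tau`, and the penalty weight `mu` are hypothetical and not taken from the paper, whose exact loss may differ.

```python
import numpy as np


def l2_normalize(z: np.ndarray) -> np.ndarray:
    # Project each embedding (row) onto the unit sphere.
    return z / np.linalg.norm(z, axis=1, keepdims=True)


def info_nce(z_x: np.ndarray, z_y: np.ndarray, tau: float = 0.07) -> float:
    # Symmetric image-to-text / text-to-image InfoNCE over a batch of pairs;
    # row i of z_x is assumed to match row i of z_y.
    logits = l2_normalize(z_x) @ l2_normalize(z_y).T / tau

    def ce_diag(l: np.ndarray) -> float:
        l = l - l.max(axis=1, keepdims=True)  # numerical stability
        log_prob = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return float(-np.mean(np.diag(log_prob)))

    return 0.5 * (ce_diag(logits) + ce_diag(logits.T))


def anchored_local_loss(z_x, z_y, z_x_star, z_y_star, mu=0.1, tau=0.07):
    # Hypothetical local objective: contrastive alignment plus a penalty that
    # keeps local embeddings close to the frozen global anchors z*_x, z*_y.
    align = info_nce(z_x, z_y, tau)
    anchor_x = np.mean(np.sum((l2_normalize(z_x) - l2_normalize(z_x_star)) ** 2, axis=1))
    anchor_y = np.mean(np.sum((l2_normalize(z_y) - l2_normalize(z_y_star)) ** 2, axis=1))
    return align + mu * float(anchor_x + anchor_y)


rng = np.random.default_rng(0)
z_img, z_txt = rng.normal(size=(8, 16)), rng.normal(size=(8, 16))
print(anchored_local_loss(z_img, z_txt, z_img, z_txt))  # anchors equal local: penalty is 0
```

When the anchors coincide with the local embeddings the penalty vanishes, so the objective reduces to plain contrastive alignment; the further local training drifts from the global representation, the larger the added term.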

Paper Structure (27 sections, 2 theorems, 26 equations, 6 figures, 13 tables, 3 algorithms)

Key Result

Proposition 1

Let $\{D_i, f_{\psi_i},f_{\phi_i}\}_{i=1}^N$ and $D_{\mathcal{T}}, f_{\psi_{\mathcal{T}}},f_{\phi_{\mathcal{T}}}$ be the distributions and optimal encoders for the client data domains and for the testing domain, respectively. Given mixture weights $\{w_i\}_{i=1}^N$ with $\sum_{i=1}^N w_i = 1$ and $w_i \geq 0$,
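The statement of Proposition 1 is truncated here, so the sketch below does not reproduce its claim. Assuming the weights define a mixture $\sum_i w_i D_i$ of the client distributions, it only illustrates two ingredients: the simplex constraints on $w$, and the fact that, by linearity of expectation, the expected loss under such a mixture is the $w$-weighted average of the per-client expected losses. The helper names and the loss values are hypothetical.

```python
import numpy as np


def check_mixture_weights(w, atol: float = 1e-8) -> bool:
    # Simplex conditions from the proposition: w_i >= 0 and sum_i w_i = 1.
    w = np.asarray(w, dtype=float)
    return bool(np.all(w >= -atol) and abs(w.sum() - 1.0) <= atol)


def mixture_risk(w, client_losses) -> float:
    # Expected loss under the mixture sum_i w_i D_i, given hypothetical
    # per-client expected losses (linearity of expectation).
    if not check_mixture_weights(w):
        raise ValueError("weights must be nonnegative and sum to 1")
    return float(np.dot(np.asarray(w, dtype=float), np.asarray(client_losses, dtype=float)))


def worst_case_over_simplex(client_losses) -> np.ndarray:
    # With no radius constraint, the linear objective is maximized at a vertex:
    # all weight goes to the client with the largest loss.
    w = np.zeros(len(client_losses))
    w[int(np.argmax(client_losses))] = 1.0
    return w


losses = [0.5, 1.2, 0.8]
print(mixture_risk([0.2, 0.3, 0.5], losses))
print(worst_case_over_simplex(losses))
```

Because the mixture risk is linear in $w$, an unconstrained worst case over the whole simplex collapses onto a single client; this is one reason distributionally robust formulations typically restrict the weights to an uncertainty set of limited radius.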

Figures (6)

  • Figure 1: (1) We aim to tackle the data-hungry problem in VLP by utilizing private multimodal paired data in a federated manner. Pre-training local models on heterogeneous client datasets may overfit the observed local data. (2) The naive method ignores the disparity between the local distribution and the potential testing distribution, and directly averaging locally pre-trained models yields a less generalizable model. In contrast, our method optimizes model performance over a family of potential testing distributions, dynamically assigns weights to schedule local training, and obtains a more transferable model.
  • Figure 2: Our proposed FedDRA method follows a two-stage pre-training scheme. The alignment modules are trained in the first stage; the whole model is then updated jointly in the second stage. During local training, we use frozen copies of the server-aggregated models to obtain global representations $z^*_x$ and $z^*_y$, which regularize local training. After each round of local training, the server estimates the current worst-case distribution and updates the weights $\lambda_i$, which are then sent to each client to adjust its local update step size.
  • Figure 3: (a) Comparison of retrieval accuracy on each client, denoted $\{C_i\}_{i=1}^5$, for centralized and FedAvg pre-training, together with the averaged accuracy of decentralized pre-trained models. (b) Performance of the server model after 25 communication rounds and averaged performance of the corresponding client models after 25 and 26 communication rounds, on each client, denoted $\{\text{Client}_i\}_{i=1}^5$. (c) Averaged accuracy on each client. We show the accuracy of the centralized and FedAvg pre-trained baselines, and of the decentralized pre-trained models, denoted $\{C_i\}_{i=1}^5$, retrained on the union of the training splits of the client datasets.
  • Figure 4: (a) Analysis of the number of clients. (b) Analysis of the uncertainty radius. (c) Analysis of the global constraint degree. (d) Improvement in image-text embedding similarity per communication round. (e) Image-text embedding similarity curve of our method. (f) Averaged similarity curves.
  • Figure 5: Illustration of the strong connection between the latent variable and the text modality.
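The server-side loop in Figure 2 (estimate the current worst-case distribution, update $\lambda_i$, send it back to rescale each client's local step size) can be sketched as follows. This is a hedged illustration only: it swaps in an exponentiated-gradient ascent step on the probability simplex, a common technique in group-style distributionally robust optimization, plus a linear mapping from $\lambda_i$ to learning rates. Neither choice is confirmed by the text above, the names `update_client_weights`, `local_step_sizes`, `eta`, and `base_lr` are hypothetical, and the uncertainty radius studied in Figure 4(b) is not modeled.

```python
import numpy as np


def update_client_weights(lam, client_losses, eta: float = 0.5) -> np.ndarray:
    """One exponentiated-gradient ascent step on the probability simplex.

    Hypothetical stand-in for the server's worst-case distribution estimate:
    clients reporting higher loss receive more weight. Requires strictly
    positive input weights; the output is positive and sums to 1.
    """
    lam = np.asarray(lam, dtype=float)
    losses = np.asarray(client_losses, dtype=float)
    logits = np.log(lam) + eta * losses
    logits -= logits.max()  # numerical stability before exponentiating
    new = np.exp(logits)
    return new / new.sum()


def local_step_sizes(lam, base_lr: float = 1e-4) -> np.ndarray:
    # Hypothetical mapping from weights to per-client learning rates:
    # uniform weights (1/N each) recover base_lr for every client.
    lam = np.asarray(lam, dtype=float)
    return base_lr * len(lam) * lam


lam = np.full(3, 1 / 3)
for _ in range(5):  # e.g. five communication rounds with fixed reported losses
    lam = update_client_weights(lam, [0.5, 1.2, 0.8])
print(lam, local_step_sizes(lam))
```

The multiplicative update keeps the weights on the simplex without an explicit projection; a larger `eta` shifts weight toward the worst-performing client faster, which is the behavior an uncertainty radius would otherwise bound.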

Theorems & Definitions (4)

  • Proposition 1
  • Lemma 1
  • Proof
  • Proof