Tackling Data Heterogeneity in Federated Learning via Loss Decomposition

Shuang Zeng, Pengxin Guo, Shuai Wang, Jianbo Wang, Yuyin Zhou, Liangqiong Qu

TL;DR

The paper addresses data heterogeneity in cross-device Federated Learning, particularly for medical imaging, by decomposing the global loss into a local loss, a distribution shift loss, and an aggregation loss. It then proposes FedLD, which jointly minimizes all three terms through margin control regularization in local training and a principal gradient-based server aggregation that projects local gradients onto principal directions and aggregates them to reduce gradient conflicts. The method is validated on retinal and chest X-ray classification tasks, where FedLD outperforms standard FL baselines across varying heterogeneity levels, with ablations demonstrating the individual contributions of margin control and principal-gradient aggregation. The framework can be integrated into existing FL pipelines to improve robustness in heterogeneous data settings, and the code is publicly available at https://github.com/Zeng-Shuang/FedLD.
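The summary does not spell out the margin control regularizer, so the sketch below is only a hypothetical stand-in showing where such a term would sit in the local objective: cross-entropy plus `lam` times a penalty on per-sample logit margins (true-class logit minus the largest competing logit) that exceed a cap `tau`. The names `logit_margin`, `margin_penalty`, `local_objective`, `lam`, and `tau`, and the choice of a squared hinge above a cap, are assumptions for illustration, not FedLD's actual formulation. NumPy is used so the example runs without a deep learning framework.

```python
import numpy as np

def cross_entropy(logits, labels):
    # Numerically stable mean cross-entropy via log-softmax
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def logit_margin(logits, labels):
    # Per-sample margin: true-class logit minus the largest competing logit
    logits = np.asarray(logits, dtype=float)
    idx = np.arange(len(labels))
    true_logit = logits[idx, labels]
    others = logits.copy()
    others[idx, labels] = -np.inf  # mask the true class
    return true_logit - others.max(axis=1)

def margin_penalty(logits, labels, tau):
    # HYPOTHETICAL regularizer: squared hinge on margins larger than tau
    excess = np.maximum(logit_margin(logits, labels) - tau, 0.0)
    return (excess ** 2).mean()

def local_objective(logits, labels, lam=0.1, tau=2.0):
    # Local training loss = cross-entropy + lam * (assumed) margin penalty
    return cross_entropy(logits, labels) + lam * margin_penalty(logits, labels, tau)
```

For logits `[[5, 0, 0], [0.5, 1, 0]]` with labels `[0, 1]`, the margins are 5.0 and 0.5, so with `tau=2.0` only the first sample is penalized and `margin_penalty` returns (5 - 2)^2 / 2 = 4.5.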

Abstract

Federated Learning (FL) is an emerging approach to collaborative, privacy-preserving machine learning in which large-scale medical datasets remain local to each client. However, data heterogeneity among clients often causes local models to diverge, leading to suboptimal global models. To mitigate the impact of data heterogeneity on FL performance, we start by analyzing how FL training influences FL performance, decomposing the global loss into three terms: local loss, distribution shift loss, and aggregation loss. Remarkably, our loss decomposition reveals that existing local training-based FL methods attempt to reduce the distribution shift loss, while global aggregation-based FL methods propose better aggregation strategies to reduce the aggregation loss. Nevertheless, few existing methods jointly minimize all three terms, which leads to subpar performance under data heterogeneity. To fill this gap, we propose a novel FL method based on global loss decomposition, called FedLD, that jointly reduces these three loss terms. FedLD combines a margin control regularization in local training to reduce the distribution shift loss with a principal gradient-based server aggregation strategy to reduce the aggregation loss. Notably, under different levels of data heterogeneity, our strategies achieve better and more robust performance on retinal and chest X-ray classification than other FL algorithms. Our code is available at https://github.com/Zeng-Shuang/FedLD.
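The abstract names the three terms without giving the formula. One identity consistent with those names, assumed here for illustration and not copied from the paper, uses client weights $p_k$ with $\sum_k p_k = 1$, local models $w_k$, the aggregate $\bar{w} = \sum_k p_k w_k$, client losses $L_k$, and the global loss $L = \sum_k p_k L_k$: $L(\bar{w}) = \sum_k p_k L_k(w_k) + \sum_k p_k [L(w_k) - L_k(w_k)] + [L(\bar{w}) - \sum_k p_k L(w_k)]$, read as local loss + distribution shift loss + aggregation loss. The right-hand side telescopes to the left, so it holds for any losses. The sketch below checks it numerically on hypothetical quadratic client losses $L_k(w) = \frac{1}{2}\lVert w - c_k \rVert^2$ after a few local gradient steps; `quad_loss`, `local_gd`, `decompose`, and the client centers are all invented for the example.

```python
import numpy as np

def quad_loss(w, center):
    # Hypothetical client loss L_k(w) = 0.5 * ||w - c_k||^2
    return 0.5 * float(np.sum((w - center) ** 2))

def local_gd(w0, center, lr=0.3, steps=5):
    # A few local gradient-descent steps on the client's own loss
    w = w0.copy()
    for _ in range(steps):
        w = w - lr * (w - center)  # gradient of quad_loss is (w - center)
    return w

def decompose(local_models, centers, p):
    def global_loss(w):
        # Global loss as the p-weighted mixture of client losses (assumption)
        return sum(pk * quad_loss(w, ck) for pk, ck in zip(p, centers))

    w_bar = sum(pk * wk for pk, wk in zip(p, local_models))  # aggregated model
    local = sum(pk * quad_loss(wk, ck) for pk, wk, ck in zip(p, local_models, centers))
    shift = sum(pk * (global_loss(wk) - quad_loss(wk, ck))
                for pk, wk, ck in zip(p, local_models, centers))
    agg = global_loss(w_bar) - sum(pk * global_loss(wk) for pk, wk in zip(p, local_models))
    return local, shift, agg, global_loss(w_bar)

# Three hypothetical clients with different optima (heterogeneous data)
centers = [np.array([1.0, 0.0]), np.array([-1.0, 2.0]), np.array([0.0, -3.0])]
p = [0.5, 0.3, 0.2]
w_global = np.zeros(2)
local_models = [local_gd(w_global, c) for c in centers]
local, shift, agg, total = decompose(local_models, centers, p)
print(f"local={local:.4f} shift={shift:.4f} agg={agg:.4f} total={total:.4f}")
```

Because these toy losses are convex, Jensen's inequality makes the aggregation term non-positive here; with the non-convex losses of deep networks that guarantee no longer holds.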

Paper Structure (13 sections, 8 equations, 2 figures, 3 tables)

Figures (2)

  • Figure 1: Overview of the proposed FedLD. (1) Each local client downloads the global model parameter $\boldsymbol{w}$; (2) it then trains locally with the cross-entropy loss and the proposed margin control regularization. (3) Each client uploads its local gradient to the global server. (4) The global server aggregates these local gradients with the proposed principal gradient-based server aggregation, which has three steps: first, construct principal gradients from all local gradients; second, calibrate the principal gradients and use them to revise the local gradients; third, aggregate the revised local gradients into the global gradient. (5) Finally, the server updates the global model parameter with the global gradient and sends it to the local clients for the next round.
  • Figure 2: Distribution shift loss of different local training methods (red lines) and aggregation loss of different server aggregation methods (blue lines) under different levels of heterogeneity in one FL round.
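Figure 1 describes the server-side principal gradient-based aggregation in three steps (construct principal gradients, calibrate them and revise the local gradients, aggregate) without giving formulas. The sketch below is one hypothetical reading of those steps: principal directions are the top right singular vectors of the stacked client gradients, calibration flips each direction to agree with the weighted mean gradient, and revision keeps each client's non-negative components along the calibrated directions, discarding the conflicting ones. The function name `principal_gradient_aggregate` and the parameters `num_components` and `weights` are invented; the paper's calibration and revision rules may differ.

```python
import numpy as np

def principal_gradient_aggregate(local_grads, num_components=2, weights=None):
    """Hypothetical principal-gradient aggregation (illustrative, not FedLD's exact rule)."""
    G = np.asarray(local_grads, dtype=float)  # (K clients, d parameters)
    K = G.shape[0]
    w = np.full(K, 1.0 / K) if weights is None else np.asarray(weights, dtype=float)
    mean_grad = w @ G  # weighted average gradient, used only as a reference

    # Step 1 (assumed): principal directions = top right singular vectors of G
    _, _, vt = np.linalg.svd(G, full_matrices=False)
    P = vt[:num_components]  # (r, d), orthonormal rows

    # Step 2a (assumed calibration): orient each direction along the mean gradient
    signs = np.sign(P @ mean_grad)
    signs[signs == 0] = 1.0
    P = P * signs[:, None]

    # Step 2b (assumed revision): project each client onto the directions and
    # zero out components that point against a calibrated direction
    coeffs = np.maximum(G @ P.T, 0.0)  # (K, r)
    revised = coeffs @ P  # (K, d)

    # Step 3: weighted aggregation of the revised local gradients
    return w @ revised
```

With client gradients `[3, 0]` and `[-1, 0]`, equal weights, and `num_components=1`, plain averaging gives `[1, 0]`, while this sketch drops the second client's opposing component and returns `[1.5, 0]`. When no client conflicts with the principal directions and they span all gradients, the result reduces to the ordinary weighted average.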

Theorems & Definitions (1)

  • Remark