Tackling Data Heterogeneity in Federated Learning via Loss Decomposition
Shuang Zeng, Pengxin Guo, Shuai Wang, Jianbo Wang, Yuyin Zhou, Liangqiong Qu
TL;DR
The paper addresses data heterogeneity in Federated Learning (FL) for medical imaging by decomposing the global loss into three terms: local loss, distribution shift loss, and aggregation loss. It then proposes FedLD, which jointly minimizes all three terms through a margin control regularization in local training and a principal gradient-based server aggregation that projects and aggregates client gradients along principal directions to reduce conflicts. The method is validated on retinal and chest X-ray classification tasks, where FedLD outperforms other FL algorithms across varying levels of heterogeneity, with ablations showing the individual contributions of margin control and principal-gradient aggregation. The result is a framework for improving FL robustness under heterogeneous data, with code publicly available at https://github.com/Zeng-Shuang/FedLD.
Abstract
Federated Learning (FL) is an emerging approach to collaborative, privacy-preserving machine learning in which large-scale medical datasets remain localized at each client. However, data heterogeneity among clients often causes local models to diverge, leading to suboptimal global models. To mitigate the impact of data heterogeneity on FL performance, we start by analyzing how FL training influences FL performance, decomposing the global loss into three terms: local loss, distribution shift loss, and aggregation loss. Notably, this decomposition reveals that existing local training-based FL methods attempt to reduce the distribution shift loss, while global aggregation-based FL methods propose better aggregation strategies to reduce the aggregation loss. Nevertheless, few works in the literature make a comprehensive joint effort to minimize all three terms, which leads to subpar performance under data heterogeneity. To fill this gap, we propose a novel FL method based on global loss decomposition, called FedLD, to jointly reduce these three loss terms. FedLD applies a margin control regularization in local training to reduce the distribution shift loss, and a principal gradient-based server aggregation strategy to reduce the aggregation loss. Under different levels of data heterogeneity, our strategies achieve better and more robust performance on retinal and chest X-ray classification than other FL algorithms. Our code is available at https://github.com/Zeng-Shuang/FedLD.
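
To make the idea of aggregating along principal directions more concrete, the following is a minimal sketch in Python with NumPy. It is not the FedLD implementation: the function name `principal_gradient_aggregate`, the use of a plain SVD, the choice of `k`, and the weighted-average step are all assumptions made for illustration. It only shows one generic way that client gradients can be projected onto their dominant shared directions so that components along weaker, conflicting directions are dropped.

```python
import numpy as np


def principal_gradient_aggregate(client_grads, weights=None, k=1):
    """Illustrative only (hypothetical helper, not the paper's algorithm).

    client_grads: array-like of shape (n_clients, dim), one flattened
                  gradient per client.
    weights:      optional per-client weights (e.g. local dataset sizes).
    k:            number of principal directions to keep (assumed knob).
    """
    grads = np.asarray(client_grads, dtype=float)
    n_clients = grads.shape[0]

    # Normalize client weights so they sum to one.
    if weights is None:
        w = np.full(n_clients, 1.0 / n_clients)
    else:
        w = np.asarray(weights, dtype=float)
        w = w / w.sum()

    # Rows of vt are orthonormal principal directions of the gradient
    # matrix, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(grads, full_matrices=False)
    basis = vt[:k]

    # Weighted average of client gradients, then projection onto the
    # span of the top-k principal directions.
    avg = w @ grads
    return (basis @ avg) @ basis


# Two clients agree on the first coordinate and conflict on the second.
grads = [[2.0, 1.0], [2.0, -1.0]]
print(principal_gradient_aggregate(grads, weights=[3, 1], k=1))
print(principal_gradient_aggregate(grads, weights=[3, 1], k=2))
```

With `k=1` only the direction on which the two clients agree is kept, so the conflicting second coordinate is removed from the aggregate; with `k=2` the projection spans the full space and the result is the ordinary weighted average.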
