FedCFA: Alleviating Simpson's Paradox in Model Aggregation with Counterfactual Federated Learning
Zhonghua Jiang, Jimin Xu, Shengyu Zhang, Tao Shen, Jiwei Li, Kun Kuang, Haibin Cai, Fei Wu
TL;DR
FedCFA tackles Simpson's Paradox in federated learning by generating counterfactual local samples guided by a globally aggregated average dataset and by learning more independent factors through a factor decorrelation loss. The framework combines counterfactual transformations, contrastive learning, and a principled loss design to align local data distributions with the global distribution, improving global-model accuracy and convergence under limited communication rounds. Extensive experiments across six datasets show that FedCFA outperforms strong baselines in both efficiency and accuracy in IID and non-IID settings, and ablations confirm the value of each component. The approach offers a practical, privacy-preserving way to aggregate models robustly in heterogeneous FL environments where Simpson's Paradox would otherwise degrade performance.
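The two mechanisms named above, counterfactual factor replacement guided by a global average and a decorrelation penalty on the extracted factors, can be illustrated with a minimal pure-Python sketch. It is not the authors' implementation. It assumes that each sample is already encoded as a factor vector, that per-factor importance scores are given (how FedCFA scores importance is not specified in this summary), that the server forms the global average as a size-weighted mean of client means, and that the FDC loss takes the form of a mean squared off-diagonal Pearson correlation. The function names `global_average`, `counterfactual_features`, and `fdc_loss` are hypothetical.

```python
import math


def global_average(client_means, client_sizes):
    """Server side: size-weighted mean of per-client mean factor vectors.

    Assumption: clients share only their mean vectors and sample counts.
    """
    total = sum(client_sizes)
    dim = len(client_means[0])
    return [
        sum(n * mean[j] for mean, n in zip(client_means, client_sizes)) / total
        for j in range(dim)
    ]


def counterfactual_features(z, importance, z_global, k):
    """Client side: replace each sample's k highest-scoring factors with global-average values.

    Assumption: `importance` holds one score per factor per sample; how the
    scores are obtained is left open in this sketch.
    """
    out = []
    for row, scores in zip(z, importance):
        # Indices of the k factors with the largest importance scores.
        top_k = sorted(range(len(row)), key=lambda j: scores[j], reverse=True)[:k]
        new_row = list(row)  # copy so the original sample is left untouched
        for j in top_k:
            new_row[j] = z_global[j]
        out.append(new_row)
    return out


def fdc_loss(z, eps=1e-8):
    """Mean squared off-diagonal Pearson correlation between factor dimensions.

    Assumption: one plausible decorrelation penalty, not necessarily the paper's
    formula. Lower values mean less linearly correlated (more independent) factors.
    """
    n, dim = len(z), len(z[0])
    if dim < 2:
        return 0.0  # a single factor has no pairwise correlations
    standardized = []
    for j in range(dim):
        col = [row[j] for row in z]
        mu = sum(col) / n
        sd = math.sqrt(sum((x - mu) ** 2 for x in col) / n) + eps
        standardized.append([(x - mu) / sd for x in col])
    total = 0.0
    for a in range(dim):
        for b in range(dim):
            if a != b:
                r = sum(x * y for x, y in zip(standardized[a], standardized[b])) / n
                total += r * r
    return total / (dim * (dim - 1))


# Usage with toy numbers (hypothetical data).
z_global = global_average([[0.0, 0.0, 0.0], [4.0, 8.0, 12.0]], [1, 3])  # [3.0, 6.0, 9.0]
local = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
scores = [[0.1, 0.9, 0.5], [0.8, 0.2, 0.3]]
cf = counterfactual_features(local, scores, z_global, k=1)  # [[1.0, 6.0, 3.0], [3.0, 5.0, 6.0]]
```

In a real training loop the decorrelation term would be computed inside an autodiff framework and added to the task loss with some weight (a hypothetical hyperparameter); plain Python is used here only to keep the arithmetic visible.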
Abstract
Federated learning (FL) is a promising technology for data privacy and distributed optimization, but it suffers from data imbalance and heterogeneity among clients. Existing FL methods try to solve these problems by aligning client models with the server model or by correcting client models with control variables. These methods excel on IID and general Non-IID data but achieve only mediocre performance in Simpson's Paradox scenarios. Simpson's Paradox refers to the phenomenon in which a trend observed on the global dataset disappears or reverses on a subset; in FL, this can cause the global model obtained through aggregation to misrepresent the distribution of the global data. We therefore propose FedCFA, a novel FL framework that uses counterfactual learning to generate counterfactual samples by replacing critical factors of local data with those of the global average data, aligning local data distributions with the global distribution and mitigating the effects of Simpson's Paradox. In addition, to improve the quality of counterfactual samples, we introduce a factor decorrelation (FDC) loss that reduces the correlation among features and thus improves the independence of the extracted factors. We conduct extensive experiments on six datasets and verify that our method outperforms other FL methods in terms of efficiency and global model accuracy under limited communication rounds.
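To make the definition of Simpson's Paradox concrete, here is a toy example with hypothetical data (not taken from the paper's experiments): two clients whose local data each show a decreasing trend, while the pooled data show an increasing one. Averaging the locally fitted linear models, a deliberately simplified one-shot stand-in for FL aggregation rather than FedAvg itself, keeps the local slope of -1 and therefore contradicts the positive slope fitted on the pooled data.

```python
def fit_line(xs, ys):
    """Ordinary least squares for y = a + b*x; returns (a, b)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    b = sxy / sxx
    return my - b * mx, b


# Hypothetical client datasets: y falls as x rises within each client.
clients = {
    "client_1": ([1, 2, 3], [3, 2, 1]),
    "client_2": ([6, 7, 8], [9, 8, 7]),
}

# Each client fits its own model on local data only.
local = {name: fit_line(xs, ys) for name, (xs, ys) in clients.items()}

# One-shot aggregation: plain average of the local (a, b) parameters.
avg_a = sum(a for a, _ in local.values()) / len(local)
avg_b = sum(b for _, b in local.values()) / len(local)

# Reference: a single model fitted on the pooled (global) data.
pooled_x = [x for xs, _ in clients.values() for x in xs]
pooled_y = [y for _, ys in clients.values() for y in ys]
pooled_a, pooled_b = fit_line(pooled_x, pooled_y)

for name, (a, b) in local.items():
    print(f"{name}: slope {b:+.3f}")
print(f"averaged model: slope {avg_b:+.3f}")
print(f"pooled data:    slope {pooled_b:+.3f}")
# Prints:
# client_1: slope -1.000
# client_2: slope -1.000
# averaged model: slope -1.000
# pooled data:    slope +0.988
```

The averaged parameters encode the per-client trend rather than the trend of the global data, a simplified picture of the aggregation failure the abstract describes.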
