FedPeWS: Personalized Warmup via Subnetworks for Enhanced Heterogeneous Federated Learning
Nurbek Tastan, Samuel Horvath, Martin Takac, Karthik Nandakumar
TL;DR
FedPeWS tackles extreme non-IID data in federated learning by introducing a personalized warmup: during the early rounds, each participant trains a subnetwork selected by learnable neuron-level masks, after which training switches to standard full-parameter federated optimization. The masks are learned by gradient descent using a mask-generation function and a diversity term; a fixed-subnetwork variant can be used instead when the data distributions are known. Across synthetic, image, and medical datasets, FedPeWS consistently reduces the number of communication rounds needed to reach a target accuracy and improves final performance, and it is robust to the hyperparameters that govern mask learning and warmup length. The approach is a practical, plug-and-play enhancement to existing FL optimizers, offering faster convergence and better generalization in highly heterogeneous cross-silo settings.
Abstract
Statistical data heterogeneity is a significant barrier to convergence in federated learning (FL). While prior work has advanced heterogeneous FL through better optimization objectives, these methods fall short when data heterogeneity among the collaborating participants is extreme. We hypothesize that convergence under extreme data heterogeneity is hindered primarily by the aggregation of conflicting updates from the participants in the initial collaboration rounds. To overcome this problem, we propose a warmup phase in which each participant learns a personalized mask and updates only a subnetwork of the full model. This personalized warmup allows the participants to focus initially on learning subnetworks tailored to their own data. After the warmup phase, the participants revert to standard federated optimization, in which all parameters are communicated. We empirically demonstrate that the proposed personalized warmup via subnetworks (FedPeWS) approach improves accuracy and convergence speed over standard federated optimization methods.
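The warmup-then-switch schedule described above can be sketched as a toy simulation. This is an illustration under stated assumptions, not the authors' implementation: it uses the fixed-subnetwork variant (masks given in advance rather than learned with the mask-generation function and diversity term), treats each scalar parameter as a "neuron", gives each client a hypothetical quadratic loss 0.5·||w − t_k||² with conflicting targets t_k, and assumes a mask-aware aggregation in which each parameter is averaged only over the clients whose mask covers it. All function names (`local_update`, `masked_average`, `fedpews_sketch`) are hypothetical.

```python
from typing import List, Optional, Sequence

Vector = List[float]


def local_update(w: Sequence[float], target: Sequence[float], mask: Sequence[int],
                 local_steps: int, lr: float) -> Vector:
    """Gradient descent on the toy loss 0.5 * ||w - target||^2 (an assumption),
    updating only the coordinates where mask == 1, i.e. the client's subnetwork."""
    w = list(w)
    for _ in range(local_steps):
        for j in range(len(w)):
            if mask[j]:
                grad = w[j] - target[j]
                w[j] -= lr * grad
    return w


def masked_average(global_w: Sequence[float], client_ws: Sequence[Sequence[float]],
                   masks: Sequence[Sequence[int]]) -> Vector:
    """Assumed mask-aware aggregation: average each coordinate over the clients
    whose mask covers it; uncovered coordinates keep the previous global value."""
    new_w = []
    for j in range(len(global_w)):
        covered = [cw[j] for cw, m in zip(client_ws, masks) if m[j]]
        new_w.append(sum(covered) / len(covered) if covered else global_w[j])
    return new_w


def fedpews_sketch(targets: Sequence[Sequence[float]], masks: Sequence[Sequence[int]],
                   warmup_rounds: int, total_rounds: int, local_steps: int = 5,
                   lr: float = 0.1, init: Optional[Sequence[float]] = None) -> Vector:
    """Toy FL loop: personalized masks during warmup, then full-parameter averaging."""
    d = len(targets[0])
    w = list(init) if init is not None else [0.0] * d
    full_mask = [1] * d
    for rnd in range(total_rounds):
        # Warmup rounds use each client's fixed subnetwork mask; afterwards every
        # client trains all parameters, so masked_average reduces to plain FedAvg.
        round_masks = masks if rnd < warmup_rounds else [full_mask] * len(targets)
        client_ws = [local_update(w, t, m, local_steps, lr)
                     for t, m in zip(targets, round_masks)]
        w = masked_average(w, client_ws, round_masks)
    return w


if __name__ == "__main__":
    # Two clients with conflicting targets and disjoint (hypothetical) masks.
    targets = [[1.0, 1.0, 0.0, 0.0], [-1.0, -1.0, 2.0, 2.0]]
    masks = [[1, 1, 0, 0], [0, 0, 1, 1]]
    print(fedpews_sketch(targets, masks, warmup_rounds=50, total_rounds=50))   # ≈ [1, 1, 2, 2]
    print(fedpews_sketch(targets, masks, warmup_rounds=10, total_rounds=200))  # ≈ [0, 0, 1, 1]
```

The toy only shows the mechanics: during warmup, disjoint masks keep the two clients' conflicting updates from being averaged together, and once the schedule switches to full-parameter averaging the model converges to the plain FedAvg solution (here the mean of the targets). It does not reproduce the convergence gains reported for FedPeWS.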
