A Novel Algorithm for Personalized Federated Learning: Knowledge Distillation with Weighted Combination Loss
Hengrui Hu, Anai N. Kothari, Anjishnu Banerjee
TL;DR
The paper addresses the challenges of non-IID data in federated learning (FL) by introducing pFedKD-WCL, a personalized FL method that combines knowledge distillation (KD) with a weighted combination loss within a bi-level optimization framework. The global model serves as a teacher that guides local personalization, while local updates inform the global model through KL-divergence-based alignment, balancing global convergence against local personalization. Experiments on MNIST and a synthetic non-IID dataset show that pFedKD-WCL achieves higher accuracy and faster convergence than FedAvg, FedProx, Per-FedAvg, and pFedMe, although its performance depends on model complexity and on the KD weight. The work offers a robust, privacy-preserving approach to adapting federated models to heterogeneous client data and highlights the need for adaptive KD weighting when scaling to larger deployments.
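The weighted combination loss can be illustrated with a minimal sketch. This section does not give its exact form, so the code assumes the common distillation objective (1 - lam) * CE(labels, student) + lam * T^2 * KL(teacher || student), with temperature-softened teacher and student distributions. The names `weighted_combination_loss`, `lam` (the KD weight), and `temperature` are hypothetical and chosen for illustration only.

```python
import numpy as np

def softmax(z, temperature=1.0):
    """Row-wise softmax with optional temperature, shifted for numerical stability."""
    z = z / temperature
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def weighted_combination_loss(student_logits, teacher_logits, labels, lam=0.5, temperature=2.0):
    """Assumed form (not confirmed by the source): (1 - lam) * CE + lam * T^2 * KL(teacher || student)."""
    n = student_logits.shape[0]
    # Hard-label cross-entropy on the student's unsoftened predictions.
    p = softmax(student_logits)
    ce = -np.mean(np.log(p[np.arange(n), labels] + 1e-12))
    # KL divergence between softened teacher (global model) and softened student (local model).
    q_t = softmax(teacher_logits, temperature)
    p_t = softmax(student_logits, temperature)
    kl = np.mean(np.sum(q_t * (np.log(q_t + 1e-12) - np.log(p_t + 1e-12)), axis=1))
    # T^2 keeps the KD gradient magnitude comparable across temperatures.
    return (1.0 - lam) * ce + lam * temperature ** 2 * kl

# Usage: random logits for a batch of 4 samples over 10 classes.
rng = np.random.default_rng(0)
student = rng.normal(size=(4, 10))
teacher = rng.normal(size=(4, 10))
labels = np.array([0, 3, 7, 9])
print(weighted_combination_loss(student, teacher, labels, lam=0.3))
```

Setting `lam = 0` recovers plain local training, while larger values pull the personalized model toward the global teacher, which is consistent with the TL;DR's observation that performance depends on the KD weight.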
Abstract
Federated learning (FL) offers a privacy-preserving framework for distributed machine learning, enabling collaborative model training across diverse clients without centralizing sensitive data. However, statistical heterogeneity, in which client data are not independent and identically distributed (non-IID), causes model drift and poor generalization. This paper proposes a novel algorithm, pFedKD-WCL (Personalized Federated Knowledge Distillation with Weighted Combination Loss), which integrates knowledge distillation with bi-level optimization to address the non-IID problem. pFedKD-WCL uses the current global model as a teacher to guide local models, efficiently optimizing for both global convergence and local personalization. We evaluate pFedKD-WCL on the MNIST dataset and on a synthetic dataset with non-IID partitioning, using multinomial logistic regression and multilayer perceptron models. Experimental results demonstrate that pFedKD-WCL outperforms state-of-the-art algorithms, including FedAvg, FedProx, Per-FedAvg, and pFedMe, in both accuracy and convergence speed.
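To show how a teacher-guided local step might fit into federated rounds, here is a minimal sketch with multinomial logistic regression on a small synthetic non-IID split. Everything in it is an assumption for illustration: the data generator `make_clients` (label skew of two classes per client plus a per-client feature shift) does not reproduce the paper's synthetic dataset, the hyperparameters are arbitrary, and the server step is plain FedAvg-style averaging rather than the paper's KL-based global update, whose details are not described here. The local loss is the assumed form (1 - lam) * CE + lam * T^2 * KL(teacher || student), optimized by gradient descent.

```python
import numpy as np

def softmax(z, temperature=1.0):
    """Row-wise softmax with temperature, shifted for numerical stability."""
    z = z / temperature
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def local_update(W_global, X, y, lam=0.5, temperature=2.0, lr=0.01, steps=50):
    """Personalize a copy of the global weights; the frozen global model acts as teacher."""
    n = X.shape[0]
    Y = np.eye(W_global.shape[1])[y]              # one-hot labels
    q_t = softmax(X @ W_global, temperature)      # teacher targets, fixed during local training
    W = W_global.copy()
    for _ in range(steps):
        z = X @ W
        # Gradient w.r.t. logits of (1 - lam) * CE + lam * T^2 * KL(q_t || softmax(z / T)).
        grad_z = (1 - lam) * (softmax(z) - Y) + lam * temperature * (softmax(z, temperature) - q_t)
        W = W - lr * X.T @ grad_z / n
    return W

def make_clients(rng, num_clients=5, dim=20, num_classes=10, n=200):
    """Hypothetical non-IID split: two classes per client plus a per-client feature shift."""
    centers = rng.normal(scale=2.0, size=(num_classes, dim))
    clients = []
    for _ in range(num_clients):
        classes = rng.choice(num_classes, size=2, replace=False)
        y = rng.choice(classes, size=n)
        X = centers[y] + rng.normal(scale=0.5, size=dim) + rng.normal(size=(n, dim))
        clients.append((X, y))
    return clients

def accuracy(W, X, y):
    return float(np.mean(np.argmax(X @ W, axis=1) == y))

def run(rounds=10, num_classes=10, seed=0):
    rng = np.random.default_rng(seed)
    clients = make_clients(rng, num_classes=num_classes)
    W_global = np.zeros((clients[0][0].shape[1], num_classes))
    for _ in range(rounds):
        local_models = [local_update(W_global, X, y) for X, y in clients]
        # Plain averaging: an assumption, not the paper's KL-based global step.
        W_global = np.mean(local_models, axis=0)
    personalized = [local_update(W_global, X, y) for X, y in clients]
    return W_global, personalized, clients

W_global, personalized, clients = run()
for i, ((X, y), W) in enumerate(zip(clients, personalized)):
    print(f"client {i}: global acc {accuracy(W_global, X, y):.2f}, "
          f"personalized acc {accuracy(W, X, y):.2f}")
```

The learning rate is kept small because the softmax cross-entropy and the temperature-scaled KL term both have gradients whose Lipschitz constant grows with the squared feature norm; a larger step can make plain gradient descent oscillate on unnormalized features.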
