Dual-Distilled Heterogeneous Federated Learning with Adaptive Margins for Trainable Global Prototypes
Fatema Siddika, Md Anwar Hossen, Wensheng Zhang, Anuj Sharma, Juan Pablo Muñoz, Ali Jannesari
TL;DR
The paper tackles heterogeneity in Federated Learning by introducing FedProtoKD, a framework that combines dual knowledge distillation with adaptive, class-wise trainable prototypes to prevent prototype margins from shrinking during aggregation. It employs a learnable projection to align heterogeneous feature spaces, a contrastive generator to synthesize server prototypes with per-class adaptive margins, and variance-weighted logit aggregation plus quality-aware prioritization of public data to distill robust global knowledge. The approach is validated on CIFAR-10/100 and Tiny-ImageNet across extreme and moderate non-IID settings, showing consistent improvements in server and client accuracy and demonstrating robustness to model and data heterogeneity as well as to scalability factors. Overall, FedProtoKD advances prototype-based HFL by maintaining discriminative class boundaries and efficient cross-model knowledge transfer while preserving client privacy.
Abstract
Heterogeneous Federated Learning (HFL) has gained significant attention for its capacity to handle both model and data heterogeneity across clients. Prototype-based HFL methods have emerged as a promising solution to statistical and model heterogeneity as well as privacy challenges, paving the way for new advancements in HFL research. These methods focus on sharing class-representative prototypes among heterogeneous clients. However, aggregating these prototypes via standard weighted averaging often yields sub-optimal global knowledge. Specifically, averaging shrinks the decision margins of the aggregated prototypes, thereby degrading model performance under model heterogeneity and non-IID data distributions. We propose FedProtoKD for the HFL setting, which uses an enhanced dual-knowledge distillation mechanism to improve system performance by leveraging clients' logits and prototype feature representations. The framework aims to resolve the prototype margin-shrinking problem with a trainable server prototype, learned via contrastive learning under a class-wise adaptive prototype margin. Furthermore, the framework assesses the importance of public samples using the closeness of each sample's prototype to its class-representative prototypes, which enhances learning performance. FedProtoKD improved test accuracy by an average of 1.13% and up to 34.13% across various settings, significantly outperforming existing state-of-the-art HFL methods.
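The margin-shrinking effect of weighted averaging can be illustrated with a toy example. The sketch below is not the FedProtoKD implementation: it assumes a 2-D feature space, two classes, and three hypothetical clients whose feature spaces differ only by a rotation (a stand-in for model heterogeneity). The helper names `rotation`, `min_pairwise_distance`, and `weighted_average_prototypes` are illustrative. Every client separates its two class prototypes by a distance of 2, yet the sample-count-weighted average of those prototypes is separated by only about 1.33.

```python
import numpy as np

def rotation(theta):
    """2-D rotation matrix for angle theta (radians)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

def min_pairwise_distance(protos):
    """Smallest Euclidean distance between any two class prototypes."""
    diffs = protos[:, None, :] - protos[None, :, :]
    dists = np.linalg.norm(diffs, axis=-1)
    return dists[np.triu_indices(len(protos), k=1)].min()

def weighted_average_prototypes(client_protos, client_counts):
    """Standard weighted averaging of per-client class prototypes.

    client_protos: array of shape (K, C, D), one prototype per client and class.
    client_counts: array of shape (K, C), samples per client and class.
    """
    weights = client_counts / client_counts.sum(axis=0, keepdims=True)
    return (weights[:, :, None] * client_protos).sum(axis=0)

# Assumption: each client embeds the two classes at distance 2,
# but in a differently rotated feature space.
base = np.array([[1.0, 0.0], [-1.0, 0.0]])
angles = np.deg2rad([0.0, 60.0, -60.0])
client_protos = np.stack([base @ rotation(t).T for t in angles])
client_counts = np.full((3, 2), 100.0)

local_margins = [min_pairwise_distance(p) for p in client_protos]
global_protos = weighted_average_prototypes(client_protos, client_counts)
global_margin = min_pairwise_distance(global_protos)

print("local margins:", np.round(local_margins, 3))
print("global margin:", round(float(global_margin), 3))
```

In this toy setting the averaged prototypes lie closer together than any client's own prototypes, which is the kind of shrinkage that a trainable server prototype with a class-wise margin is meant to counteract.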
