Adaptive Self-Distillation for Minimizing Client Drift in Heterogeneous Federated Learning

M Yashwanth, Gaurav Kumar Nayak, Arya Singh, Yogesh Simmhan, Anirban Chakraborty

TL;DR

Federated Learning suffers from client drift when data are non-iid, especially under label distribution imbalance. The authors propose Adaptive Self-Distillation (ASD), a per-sample KL-divergence regularizer whose weights depend on the global model's prediction entropy and the client's label distribution; it requires no auxiliary data and can be easily added to existing FL methods. They provide a theoretical analysis showing reduced gradient dissimilarity and evidence of improved generalization via Hessian flatness, complemented by extensive experiments on CIFAR-10/100 and Tiny-ImageNet in which ASD yields consistent gains across baselines. The work demonstrates a practical, low-overhead, plug-and-play approach to robust Federated Learning under label heterogeneity that improves convergence and performance without additional communication.
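The regularizer described above can be sketched as follows. This is a minimal, framework-free illustration, not the authors' implementation: the per-sample weight exp(-H(p_global)) / prior[y], normalized over the batch, is an assumed form chosen only to show how the global model's prediction entropy and the client's label distribution could enter the weights. The names `asd_regularizer`, `client_loss`, and the strength `lam` are hypothetical. The global model's probabilities are treated as fixed teacher targets in KL(p_global || p_client).

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one sample's logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy (natural log) of a probability vector."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)


def kl_div(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q); assumes q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def asd_regularizer(global_logits: Sequence[Sequence[float]],
                    client_logits: Sequence[Sequence[float]],
                    labels: Sequence[int],
                    label_prior: Sequence[float]) -> float:
    """Hypothetical adaptive self-distillation term for one batch.

    ASSUMED weight form: exp(-H(global prediction)) / label_prior[label],
    normalized over the batch so the weights sum to 1.
    """
    p_global = [softmax(z) for z in global_logits]  # frozen teacher targets
    p_client = [softmax(z) for z in client_logits]
    raw = [math.exp(-entropy(pg)) / label_prior[y] for pg, y in zip(p_global, labels)]
    total = sum(raw)
    weights = [r / total for r in raw]
    return sum(w * kl_div(pg, pc) for w, pg, pc in zip(weights, p_global, p_client))


def cross_entropy(client_logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean cross-entropy of the client model on the batch."""
    return sum(-math.log(softmax(z)[y]) for z, y in zip(client_logits, labels)) / len(labels)


def client_loss(global_logits, client_logits, labels, label_prior, lam: float = 1.0) -> float:
    """Local objective: cross-entropy plus a hypothetical strength lam times the ASD term."""
    return cross_entropy(client_logits, labels) + lam * asd_regularizer(
        global_logits, client_logits, labels, label_prior)
```

When the client logits equal the global logits the KL term vanishes, so `client_loss` reduces to plain cross-entropy. In a real FL client the same computation would run on autodiff tensors so that gradients flow only through the client logits.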

Abstract

Federated Learning (FL) is a machine learning paradigm that enables clients to jointly train a global model by aggregating the locally trained models without sharing any local training data. In practice, there can often be substantial heterogeneity (e.g., class imbalance) across the local data distributions observed by each of these clients. Under such non-iid label distributions across clients, FL suffers from the 'client-drift' problem, where every client drifts to its own local optimum. This results in slower convergence and poor performance of the aggregated model. To address this limitation, we propose a novel regularization technique based on adaptive self-distillation (ASD) for training models on the client side. Our regularization scheme adaptively adjusts to each client's training data based on the global model's prediction entropy and the client-data label distribution. We show in this paper that our proposed regularization (ASD) can be easily integrated atop existing, state-of-the-art FL algorithms, leading to a further boost in the performance of these off-the-shelf methods. We theoretically explain how incorporating the ASD regularizer reduces client drift and empirically justify the generalization ability of the trained model. We demonstrate the efficacy of our approach through extensive experiments on multiple real-world benchmarks and show substantial gains in performance when the proposed regularizer is combined with popular FL methods.
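The client-server loop described in the abstract (broadcast, local training, aggregation) can be sketched in plain Python. This is an illustrative sketch under stated assumptions: parameters are flat lists of floats, `local_train` is a hypothetical callback standing in for a client's local optimization (for ASD, cross-entropy plus the adaptive KL term), every client participates in every round, and aggregation weights clients by local sample count as in standard FedAvg. The paper's exact client sampling and weighting may differ.

```python
from typing import Callable, List, Sequence

Params = List[float]


def fedavg_aggregate(client_params: Sequence[Params], num_samples: Sequence[int]) -> Params:
    """Average client parameter vectors, weighting each client by its local sample count."""
    total = sum(num_samples)
    dim = len(client_params[0])
    return [sum(n * p[j] for p, n in zip(client_params, num_samples)) / total
            for j in range(dim)]


def run_round(global_params: Params,
              local_train: Callable[[int, Params], Params],
              num_samples: Sequence[int]) -> Params:
    """One communication round with every client participating."""
    # Step 1: broadcast a copy of the global parameters to each client.
    # Step 2: each client trains locally from that copy (hypothetical callback).
    updated = [local_train(k, list(global_params)) for k in range(len(num_samples))]
    # Step 3: the server aggregates the returned client models.
    return fedavg_aggregate(updated, num_samples)


def make_quadratic_trainer(targets: Sequence[Params], lr: float = 0.1, steps: int = 5):
    """Toy local training: client k runs gradient descent on sum_j (w_j - targets[k][j])^2."""
    def local_train(k: int, params: Params) -> Params:
        for _ in range(steps):
            params = [w - lr * 2.0 * (w - t) for w, t in zip(params, targets[k])]
        return params
    return local_train


if __name__ == "__main__":
    trainer = make_quadratic_trainer([[0.0], [4.0]])
    w = [10.0]
    for _ in range(100):
        w = run_round(w, trainer, num_samples=[1, 3])
    print(w)  # approaches the sample-weighted optimum 3.0
```

The toy quadratic clients only exercise the loop; because all of them share the same curvature, they do not reproduce the client drift that ASD targets.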

Paper Structure (36 sections, 9 theorems, 56 equations, 12 figures, 14 tables)

Key Result

Proposition 3.1

$\inf_{\mathbf{w}\in\mathbb{R}^d} G_d(\mathbf{w},\lambda) = 1$ for all $\lambda$.
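Proposition 3.1 can be illustrated under an assumed definition of the gradient dissimilarity, taken from its common form in the FL literature rather than from the text above: $G_d(\mathbf{w},\lambda) = \frac{1}{N}\sum_{i=1}^{N}\|\nabla f_i(\mathbf{w})\|^2 \,/\, \|\nabla f(\mathbf{w})\|^2$ with $\nabla f = \frac{1}{N}\sum_{i}\nabla f_i$, where each local objective $f_i$ would include the ASD term with strength $\lambda$. By Jensen's inequality the numerator is never smaller than the denominator, so $G_d \ge 1$, with equality exactly when all client gradients coincide (and their average is nonzero); this is consistent with an infimum of 1 for every $\lambda$. The hypothetical helper `gradient_dissimilarity` below computes the ratio for given client gradients.

```python
from typing import List, Sequence

Vector = Sequence[float]


def sq_norm(v: Vector) -> float:
    """Squared Euclidean norm."""
    return sum(x * x for x in v)


def gradient_dissimilarity(client_grads: Sequence[Vector]) -> float:
    """ASSUMED G_d: mean squared client-gradient norm over the squared norm of the mean gradient."""
    n = len(client_grads)
    dim = len(client_grads[0])
    mean_grad: List[float] = [sum(g[j] for g in client_grads) / n for j in range(dim)]
    denom = sq_norm(mean_grad)
    if denom == 0.0:
        raise ValueError("mean gradient is zero; the ratio is undefined")
    mean_sq = sum(sq_norm(g) for g in client_grads) / n
    return mean_sq / denom


if __name__ == "__main__":
    print(gradient_dissimilarity([[1.0, 2.0], [1.0, 2.0]]))  # identical gradients → 1.0
    print(gradient_dissimilarity([[1.0, 0.0], [0.0, 1.0]]))  # disagreeing gradients → 2.0
```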

Figures (12)

  • Figure 1: Impact of one round of local training on the test accuracy of two clients with different label distributions sampled from the CIFAR-10 dataset. The effect of local learning is measured as the change in test accuracy before and after local training, with positive values indicating improved model performance. For classes with a low probability of occurrence (i.e., under-represented classes), models trained using FedAvg frequently exhibit a decline in accuracy after local training. In contrast, incorporating our proposed adaptive self-distillation regularizer (ASD) into FedAvg (FedAvg+ASD) not only effectively captures knowledge from well-represented classes but also preserves information about under-represented classes. A similar pattern is observed with FedNTD and FedNTD+ASD.
  • Figure 2: Federated Learning with Adaptive Self-Distillation: overview of the proposed approach. In Step 1, the server broadcasts the model parameters. In Step 2, clients train their models by minimizing both the cross-entropy loss and the KL divergence between the predicted class-probability distributions of the global model and the client model; the importance of each sample in the batch is decided by the proposed adaptive scheme as a function of the label distribution and the KL term. The server model is kept fixed while the client trains. In Step 3, the server aggregates the client models using FedAvg aggregation. The process repeats until convergence.
  • Figure 3: Hessian eigenspectrum with and without the ASD regularizer. The ASD regularizer reduces not only the top eigenvalue but also most of the other eigenvalues, indicating a flatter solution.
  • Figure 4: Test accuracy vs. communication rounds: comparison of algorithms on $\delta = 0.3$ data partitions of the CIFAR-100 and Tiny-ImageNet datasets. Every algorithm augmented with the proposed regularizer (ASD) outperforms its original form, and FedSpeed+ASD outperforms all other algorithms.
  • Figure 5: Test accuracy vs. communication rounds: comparison of algorithms on $\delta = 0.6$ data partitions of the CIFAR-100 and Tiny-ImageNet datasets. Every algorithm augmented with the proposed regularizer (ASD) outperforms its original form, and FedSpeed+ASD outperforms all other algorithms.
  • ...and 7 more figures
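The flatness comparison in Figure 3 rests on Hessian eigenvalues. A common way to estimate the top eigenvalue without forming the Hessian is power iteration on Hessian-vector products. The sketch below assumes a callable `hvp` (hypothetical here; in practice it would come from an autodiff framework) and uses an explicit small matrix as a toy stand-in via the hypothetical helper `matrix_hvp`. It is not presented as the procedure used to produce Figure 3, which this summary does not describe.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]


def power_iteration(hvp: Callable[[Vector], Vector], dim: int,
                    iters: int = 200, seed: int = 0) -> float:
    """Estimate the dominant eigenvalue of a symmetric operator from matrix-vector products."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    eig = 0.0
    for _ in range(iters):
        hv = hvp(v)
        eig = sum(a * b for a, b in zip(v, hv))  # Rayleigh quotient v^T H v for unit v
        norm = math.sqrt(sum(x * x for x in hv))
        if norm == 0.0:
            return 0.0  # v lies in the null space of H
        v = [x / norm for x in hv]
    return eig


def matrix_hvp(h: Sequence[Sequence[float]]) -> Callable[[Vector], Vector]:
    """Wrap an explicit matrix as a Hessian-vector product (toy stand-in for autodiff HVPs)."""
    return lambda v: [sum(hij * vj for hij, vj in zip(row, v)) for row in h]


if __name__ == "__main__":
    toy_hessian = [[3.0, 1.0], [1.0, 3.0]]  # eigenvalues 4 and 2
    print(power_iteration(matrix_hvp(toy_hessian), dim=2))
```

Power iteration returns the eigenvalue of largest magnitude, which is the top eigenvalue when the Hessian is positive semidefinite near a minimum. Recovering most of the spectrum, as in Figure 3, typically calls for a method such as Lanczos.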

Theorems & Definitions (15)

  • Proposition 3.1
  • Proposition 3.2
  • Proposition 3.4
  • Proposition 3.6
  • proof
  • Proposition A.1
  • proof
  • Lemma A.2
  • proof
  • Proposition A.3
  • ...and 5 more