
PerAda: Parameter-Efficient Federated Learning Personalization with Generalization Guarantees

Chulin Xie, De-An Huang, Wenda Chu, Daguang Xu, Chaowei Xiao, Bo Li, Anima Anandkumar

TL;DR

PerAda is a parameter-efficient pFL framework that reduces communication and computational costs, exhibits superior generalization performance under test-time distribution shifts, and is supported by generalization bounds.

Abstract

Personalized Federated Learning (pFL) has emerged as a promising solution to tackle data heterogeneity across clients in FL. However, existing pFL methods either (1) introduce high communication and computation costs or (2) overfit to local data, which can be limited in scope, leaving them vulnerable to evolved test samples with natural shifts. In this paper, we propose PerAda, a parameter-efficient pFL framework that reduces communication and computational costs and exhibits superior generalization performance, especially under test-time distribution shifts. PerAda reduces these costs by leveraging pretrained models and updating and communicating only a small number of additional adapter parameters. PerAda generalizes well because it regularizes each client's personalized adapter with a global adapter, while the global adapter uses knowledge distillation to aggregate generalized information from all clients. Theoretically, we provide generalization bounds to explain why PerAda improves generalization, and we prove its convergence to stationary points under non-convex settings. Empirically, compared with baselines on datasets across natural and medical domains, PerAda demonstrates competitive personalized performance (+4.85% on CheXpert) and better out-of-distribution generalization (+5.23% on CIFAR-10-C), while updating only 12.6% of the parameters per model through the adapter. Our code is available at https://github.com/NVlabs/PerAda.
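The regularization of a personalized adapter toward the global adapter can be illustrated with a toy example. The sketch below assumes an L2 proximal penalty $\frac{\lambda}{2}\lVert v_m - w\rVert^2$ between a client's personalized adapter $v_m$ and the global adapter $w$ (a Ditto-style choice; the abstract only states that the personalized adapter is regularized with the global one), a hypothetical quadratic local loss, and adapters flattened into lists of floats. The names `personalize`, `local_grad`, and `lam` are illustrative and do not come from the PerAda code base.

```python
"""Toy sketch: a personalized adapter regularized toward a global adapter.

Assumptions (hypothetical, not from the PerAda repository): adapters are flat
lists of floats, the local loss is 0.5 * ||v - target||^2, and the regularizer
is the L2 proximal term (lam / 2) * ||v - w_global||^2.
"""
from typing import List

Vector = List[float]


def local_grad(v: Vector, target: Vector) -> Vector:
    # Gradient of the hypothetical local loss 0.5 * ||v - target||^2.
    return [vi - ti for vi, ti in zip(v, target)]


def personalize(v: Vector, w_global: Vector, target: Vector,
                lam: float = 1.0, lr: float = 0.1, steps: int = 200) -> Vector:
    """Gradient descent on local loss + (lam / 2) * ||v - w_global||^2."""
    v = list(v)
    for _ in range(steps):
        g_local = local_grad(v, target)
        # Proximal term: pulls the personalized adapter toward the global one.
        g_prox = [lam * (vi - wi) for vi, wi in zip(v, w_global)]
        v = [vi - lr * (gl + gp) for vi, gl, gp in zip(v, g_local, g_prox)]
    return v


if __name__ == "__main__":
    w_global = [0.0, 0.0]
    target = [2.0, -4.0]
    v = personalize([0.0, 0.0], w_global, target, lam=1.0)
    # For this quadratic, the minimizer is (target + lam * w_global) / (1 + lam).
    print([round(x, 4) for x in v])  # prints [1.0, -2.0]
```

A larger `lam` keeps the personalized adapter closer to the global adapter, while `lam = 0` reduces to purely local training, the regime in which overfitting to limited local data can occur.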

Paper Structure (51 sections, 25 theorems, 98 equations, 7 figures, 8 tables, 1 algorithm)

Key Result

Theorem 1

Consider empirical datasets ${\mathbb{D}} \sim \mu, {\mathbb{D}}_{\mathtt{aux}} \sim \mu_{\mathtt{aux}}, {\mathbb{D}}_{m} \sim \mu_{m}$ with $|{\mathbb{D}}|=|{\mathbb{D}}_m|=n, |{\mathbb{D}}_{\mathtt{aux}}|=n_{\mathtt{aux}}$. Let $d_m$ be the VC dimension of ${\mathcal{H}}_m$, $\operatorname{Rad}_{n
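Theorem 1 relates the distributions $\mu$, $\mu_{\mathtt{aux}}$, and $\mu_m$ through a VC-dimension argument, and the supporting results listed among the theorems cite ben2010theory. For reference, the classical risk definition, $\mathcal{H}$-divergence, and domain adaptation bound from that work are written below in Ben-David et al.'s notation; the paper's own Definition 1, Definition 2, and Lemma 2 may use different symbols or constants.

```latex
% Risk of hypothesis h under distribution D with labeling function f
\epsilon_{D}(h, f) = \mathbb{E}_{x \sim D}\big[\, |h(x) - f(x)| \,\big]

% H-divergence between distributions D and D'
d_{\mathcal{H}}(D, D') = 2 \sup_{h \in \mathcal{H}}
  \big| \Pr_{x \sim D}[h(x) = 1] - \Pr_{x \sim D'}[h(x) = 1] \big|

% Domain adaptation bound: for every h in H (source S, target T)
\epsilon_{T}(h) \le \epsilon_{S}(h)
  + \tfrac{1}{2}\, d_{\mathcal{H}\Delta\mathcal{H}}(D_S, D_T) + \lambda,
\qquad
\lambda = \min_{h' \in \mathcal{H}} \big[ \epsilon_{S}(h') + \epsilon_{T}(h') \big]
```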

Figures (7)

  • Figure 1: Accuracy of personalized models on Office-Home. "Full"/"Partial" denotes full/partial model personalization. PerAda achieves the highest personalized performance and generalization by updating the smallest number of model parameters.
  • Figure 2: Illustration of PerAda.
  • Figure 3: Current full model personalization incurs high computation costs by training two models, whereas existing partial model personalization often falls short in terms of generalizability. By updating only the adapter, PerAda achieves a favorable balance between clients' training/communication costs and their pFL performance.
  • Figure 4: Effect of KD on PerAda evaluated on CIFAR-10. More distillation steps and data samples lead to better generalization, and out-of-domain distillation data (STL-10, CIFAR-100) achieves performance similar to that of in-domain (validation) data.
  • Figure 5: Effect of different initializations (Random, FedAvg model, and ImageNet pretrained model).
  • ...and 2 more figures
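Figure 4 varies the number of distillation steps and data samples used for knowledge distillation. The server-side aggregation can be sketched as follows under hypothetical assumptions: each client adapter and the global adapter are reduced to linear classifiers, the teacher is the average of the clients' softmax outputs on unlabeled auxiliary inputs, and the student starts from the parameter average and minimizes KL(teacher || student) by gradient descent. Here `steps` and `aux` stand in for the distillation steps and distillation data; none of these names come from the PerAda repository.

```python
"""Toy sketch of server-side knowledge distillation into a global model.

Assumptions (hypothetical): client and global models are linear classifiers
(logits = W x), the teacher averages client probabilities on auxiliary inputs,
and the student is trained by gradient descent on KL(teacher || student).
"""
import math
from typing import List

Vec = List[float]
Mat = List[List[float]]  # rows are classes, columns are input features


def logits(W: Mat, x: Vec) -> Vec:
    # Linear classifier: one logit per class (row of W).
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def softmax(z: Vec) -> Vec:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def kl(p: Vec, q: Vec) -> float:
    # KL(p || q), skipping zero-probability teacher entries.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def ensemble_targets(clients: List[Mat], aux: List[Vec]) -> List[Vec]:
    # Teacher (assumption): average of the clients' predicted probabilities.
    targets = []
    for x in aux:
        probs = [softmax(logits(W, x)) for W in clients]
        targets.append([sum(col) / len(probs) for col in zip(*probs)])
    return targets


def distill(clients: List[Mat], aux: List[Vec],
            steps: int = 100, lr: float = 0.5) -> Mat:
    K, D = len(clients[0]), len(clients[0][0])
    # Student starts from the element-wise parameter average (FedAvg-style).
    W = [[sum(c[k][j] for c in clients) / len(clients) for j in range(D)]
         for k in range(K)]
    targets = ensemble_targets(clients, aux)
    for _ in range(steps):
        grad = [[0.0] * D for _ in range(K)]
        for x, p in zip(aux, targets):
            q = softmax(logits(W, x))
            # d KL(p || softmax(W x)) / d W[k][j] = (q_k - p_k) * x_j
            for k in range(K):
                for j in range(D):
                    grad[k][j] += (q[k] - p[k]) * x[j] / len(aux)
        W = [[W[k][j] - lr * grad[k][j] for j in range(D)] for k in range(K)]
    return W


def avg_kl(W: Mat, aux: List[Vec], targets: List[Vec]) -> float:
    # Mean distillation loss over the auxiliary inputs.
    return sum(kl(p, softmax(logits(W, x)))
               for x, p in zip(aux, targets)) / len(aux)


if __name__ == "__main__":
    clients = [[[2.0, 0.0], [0.0, 2.0]], [[0.0, 1.0], [1.0, 0.0]]]
    aux = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    targets = ensemble_targets(clients, aux)
    before = avg_kl(distill(clients, aux, steps=0), aux, targets)
    after = avg_kl(distill(clients, aux, steps=100), aux, targets)
    print(f"KL before: {before:.4f}, after: {after:.6f}")
```

For linear models, the parameter average already reproduces the averaged logits, so this sketch uses averaged probabilities as the teacher; the distillation loss then shrinks as `steps` grows.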

Theorems & Definitions (48)

  • Theorem 1: Generalization bound of global model
  • Remark 1
  • Theorem 2: Generalization bound of personalized model
  • Remark 2
  • Theorem 3: Convergence of global model
  • Remark 3
  • Lemma 1: Empirical Rademacher complexity [rademacher]
  • Definition 1: Risk [ben2010theory]
  • Definition 2: ${\mathcal{H}}$-divergence [ben2010theory]
  • Lemma 2: Domain adaptation [ben2010theory]
  • ...and 38 more
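Lemma 1 concerns the empirical Rademacher complexity, which also enters Theorem 1 through the $\operatorname{Rad}$ term. Under the standard definition $\widehat{\mathfrak{R}}_S(\mathcal{H}) = \mathbb{E}_{\sigma}\big[\sup_{h \in \mathcal{H}} \frac{1}{n}\sum_{i=1}^{n} \sigma_i h(x_i)\big]$ with i.i.d. uniform signs $\sigma_i \in \{-1, +1\}$, the quantity can be estimated by Monte Carlo for a finite hypothesis class given by its predictions on the sample. This is a minimal illustration of the definition, assuming a finite class; it is not how the paper computes or bounds the term, and `empirical_rademacher` is a hypothetical helper.

```python
"""Monte Carlo estimate of the empirical Rademacher complexity of a finite class."""
import itertools
import random
from typing import Sequence


def empirical_rademacher(preds: Sequence[Sequence[float]],
                         trials: int = 2000, seed: int = 0) -> float:
    """Estimate E_sigma[ sup_h (1/n) * sum_i sigma_i * h(x_i) ].

    preds[h][i] holds hypothesis h evaluated on sample point x_i, so the
    class is represented only through its behavior on the sample.
    """
    rng = random.Random(seed)  # fixed seed for reproducible estimates
    n = len(preds[0])
    total = 0.0
    for _ in range(trials):
        # Draw i.i.d. Rademacher signs, uniform on {-1, +1}.
        sigma = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        # Supremum over the finite class of the signed correlation.
        total += max(sum(s * v for s, v in zip(sigma, h)) / n for h in preds)
    return total / trials


if __name__ == "__main__":
    # Every sign pattern on 3 points: the class can match any sigma exactly.
    all_patterns = [list(p) for p in itertools.product((-1.0, 1.0), repeat=3)]
    print(empirical_rademacher(all_patterns))  # prints 1.0
    # A class with one hypothesis and its negation on 4 points (exact value 0.375).
    h = [1.0, 1.0, 1.0, 1.0]
    print(empirical_rademacher([h, [-v for v in h]]))
```

When the class realizes every sign pattern on the sample, the supremum always equals 1, the largest possible value; smaller classes give smaller estimates and therefore tighter generalization bounds.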