Unlocking the Potential of Prompt-Tuning in Bridging Generalized and Personalized Federated Learning

Wenlong Deng, Christos Thrampoulidis, Xiaoxiao Li

TL;DR

The paper tackles federated learning with vision transformers under data heterogeneity by proposing SGPT, a prompt-tuning-based framework that combines shared prompts for universal knowledge with group prompts for local specialization. A prompt selection module assigns inputs to data groups, enabling sample-level adaptation without local fine-tuning, while a block coordinate descent optimization alternates between learning shared information and group-specific knowledge. The authors provide a theoretical bound on the global-local performance gap in terms of generalization and distribution discrepancy, and empirically validate SGPT on label- and feature-heterogeneous benchmarks, showing superior global and worst-local performance with improved efficiency. This approach offers a practical, scalable path to robust FL with ViT by leveraging prompt-tuning to navigate cross-client heterogeneity and reduce communication and computation costs.
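The prompt selection step can be illustrated with a small sketch: a frozen encoder's CLS feature for each input is compared against one key per data group, and the index of the most similar key picks the group prompt for that sample. This is a minimal sketch, not the authors' implementation. It assumes cosine similarity as the similarity measure, and the feature vectors, keys, and helper names (`cosine`, `select_group`) are hypothetical.

```python
import math
from typing import List, Sequence

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors (assumed measure)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-12)

def select_group(cls_feature: Sequence[float], keys: List[Sequence[float]]) -> int:
    """Return the index of the key most similar to the CLS feature.

    The index doubles as the id of the group prompt used for this sample,
    so routing happens per sample rather than per client.
    """
    scores = [cosine(cls_feature, k) for k in keys]
    return max(range(len(keys)), key=lambda i: scores[i])

# Hypothetical 3-dim CLS features and two keys (one per data group).
keys = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
batch = [[0.9, 0.1, 0.0], [0.2, 0.8, 0.1], [0.7, 0.6, 0.0]]
print([select_group(x, keys) for x in batch])  # → [0, 1, 0]
```

Because the routing decision depends only on each input's feature, a single global model can serve clients with different mixtures of groups without any local fine-tuning.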

Abstract

Vision Transformers (ViT) and Visual Prompt Tuning (VPT) achieve state-of-the-art performance with improved efficiency in various computer vision tasks. This suggests a promising paradigm shift of adapting pre-trained ViT models to Federated Learning (FL) settings. However, data heterogeneity among FL clients presents a significant hurdle to effectively deploying ViT models. Existing Generalized FL (GFL) and Personalized FL (PFL) methods have limitations in balancing performance across both global and local data distributions. In this paper, we present a novel algorithm, SGPT, that integrates GFL and PFL approaches by employing a unique combination of shared and group-specific prompts. This design enables SGPT to capture both common and group-specific features. A key feature of SGPT is its prompt selection module, which facilitates the training of a single global model capable of automatically adapting to diverse local client data distributions without the need for local fine-tuning. To effectively train the prompts, we utilize block coordinate descent (BCD), iteratively learning common feature information (shared prompts) and then more specialized knowledge (group prompts). Theoretically, we justify that learning the proposed prompts can reduce the gap between global and local performance. Empirically, we conduct experiments on both label and feature heterogeneity settings in comparison with state-of-the-art baselines, along with extensive ablation studies, to substantiate the superior performance of SGPT.
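The block coordinate descent (BCD) schedule alternates between a shared block and group-specific blocks. The toy sketch below shows only that alternation pattern under loud assumptions: a scalar linear model in which a shared slope `w` stands in for the shared prompts and per-group offsets `b[g]` stand in for the group prompts. The data, learning rate, step counts, and the `bcd` helper are hypothetical, and federated aggregation across clients is omitted.

```python
from typing import Dict, List, Tuple

# Hypothetical toy data: (x, y) pairs per data group. Both groups share the
# slope 2 ("common knowledge") but differ in their offset ("group knowledge").
DATA: Dict[int, List[Tuple[float, float]]] = {
    0: [(1.0, 3.0), (2.0, 5.0), (3.0, 7.0)],  # y = 2x + 1
    1: [(1.0, 1.0), (2.0, 3.0), (3.0, 5.0)],  # y = 2x - 1
}

def bcd(rounds: int = 300, steps: int = 5, lr: float = 0.1):
    """Alternate gradient steps on a shared block and on per-group blocks."""
    w = 0.0                      # shared parameter, trained on all data
    b = {g: 0.0 for g in DATA}   # group parameters, one per data group
    n_all = sum(len(pts) for pts in DATA.values())
    for _ in range(rounds):
        # Block 1: update the shared parameter with group parameters frozen.
        for _ in range(steps):
            grad_w = sum(2 * x * (w * x + b[g] - y)
                         for g, pts in DATA.items() for x, y in pts) / n_all
            w -= lr * grad_w
        # Block 2: update each group parameter on its own group's data,
        # with the shared parameter frozen.
        for g, pts in DATA.items():
            for _ in range(steps):
                grad_b = sum(2 * (w * x + b[g] - y) for x, y in pts) / len(pts)
                b[g] -= lr * grad_b
    return w, b

w, b = bcd()
print(round(w, 3), {g: round(v, 3) for g, v in b.items()})  # → 2.0 {0: 1.0, 1: -1.0}
```

The shared block recovers what all groups have in common (the slope), while each group block only has to fit the residual specific to its group, which mirrors the "common first, then specialized" ordering described in the abstract.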

Paper Structure (37 sections, 4 theorems, 35 equations, 10 figures, 6 tables, 3 algorithms)

This paper contains 37 sections, 4 theorems, 35 equations, 10 figures, 6 tables, 3 algorithms.

Key Result

Theorem 4.1

Assume the loss function $\ell$ is bounded in $[0,1]$ and the function Select is a data grouping method. Let the VC-dimension of the hypothesis class $\mathcal{H}$ be $d$. Then, with probability at least $1-\delta$ over the training set, where $\operatorname{disc}_{\mathcal{H}}\left(\mathcal{D}_1, \mathcal{D}_2\right)=\max _{h \in \mathcal{H}}\left|\mathcal{L}_{\mathcal{D}_1}(h)-\mathcal{L}_{\mathcal{D}_2}(h)\right|$ …
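The discrepancy term in Theorem 4.1 measures how differently the same hypothesis class performs on two distributions. As a minimal illustration of the definition only (the theorem's bound itself is not reproduced here), the sketch below computes $\operatorname{disc}_{\mathcal{H}}$ for a small finite class of threshold classifiers under the 0-1 loss, which is bounded in $[0,1]$ as the theorem assumes. The class `H`, the samples standing in for the global and local distributions, and the helper names are all hypothetical.

```python
from typing import Callable, List, Tuple

Sample = Tuple[float, int]           # (feature, label in {0, 1})
Hypothesis = Callable[[float], int]

def risk(h: Hypothesis, data: List[Sample]) -> float:
    """Empirical risk of h under the 0-1 loss."""
    return sum(h(x) != y for x, y in data) / len(data)

def discrepancy(hs: List[Hypothesis], d1: List[Sample], d2: List[Sample]) -> float:
    """disc_H(D1, D2) = max over h in H of |L_D1(h) - L_D2(h)|, for a finite H."""
    return max(abs(risk(h, d1) - risk(h, d2)) for h in hs)

# Hypothetical class H: threshold classifiers h_t(x) = 1[x >= t].
H = [lambda x, t=t: int(x >= t) for t in (0.0, 0.5, 1.0)]
global_data = [(0.2, 0), (0.4, 0), (0.6, 1), (0.8, 1)]  # stands in for the global distribution
local_data = [(0.6, 1), (0.8, 1)]                       # stands in for one client's distribution
print(discrepancy(H, global_data, local_data))  # → 0.5
```

Here the threshold at 0.5 is perfect on both samples, but the trivial classifiers at 0.0 and 1.0 behave very differently on the two distributions, and the discrepancy takes the worst case over the class.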

Figures (10)

  • Figure 1: Global accuracy and worst local accuracy on CIFAR-100 with $s=10$ ($s$ is the number of classes per client). Points in the top-right corner correspond to strong performance on both the global data and the local clients' data distributions. PFL models perform well on local data but generalize poorly to out-of-client data. Global models generalize better but do not adapt well to each local data distribution. Our proposed SGPT ($\star$) achieves the best trade-off.
  • Figure 2: Pipeline of our method. (a) Overview of the federated group-aware prompt-tuning procedure, SGPT. Each model comprises shared prompts and group prompts, facilitating the acquisition of both common and group-specific knowledge. The shared prompts and classification head are trained globally, while each group prompt is inserted into intermediate layers, trained within its respective data group, and shared globally. (b) The prompt selection module. Each input is first processed by a pre-trained ViT encoder. Similarities between the keys and the last-layer CLS token features are computed, and the prompt corresponding to the most similar key is selected for training, enabling group-aware training at the sample level. (c) Because data distributions vary across clients, the frequencies of selected group prompts differ, so that our model aligns with various local data distributions.
  • Figure 3: Stability analysis of one example client on CIFAR-100 with $s=10$. We plot the mean and standard deviation of the prompt selection count over all communication rounds. (a) Without stability regularization, the variance is larger and selection is unstable. (b) With our proposed momentum updating, the variance is reduced and selection is more stable.
  • Figure 4: Exploring prompt insertion layers on CIFAR-100 ($s=10$). The brown curve shows performance using only shared prompts, while the blue curve shows performance with group prompts inserted at varying layers, alongside shared prompts in layer 3.
  • Figure 5: t-SNE maps of CIFAR-100 data features processed by different layers of the ImageNet-21K pre-trained ViT-B/16 model. Data from different coarse classes are shown in different colors.
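Figure 3 attributes the more stable prompt selection to momentum updating, but the caption does not say what the momentum acts on. The sketch below assumes, hypothetically, that it acts on the selection keys: each key moves only a $(1-\beta)$ fraction of the way toward the mean CLS feature of the samples routed to it in a round, so keys drift slowly and per-round selection counts fluctuate less. The update rule, the value of `beta`, the cosine-similarity routing, and the helper names are assumptions, not the paper's specification.

```python
import math
from typing import List, Sequence

Vector = List[float]

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-12)

def assign(feats: List[Vector], keys: List[Vector]) -> List[int]:
    """Route each feature to the key with the highest cosine similarity."""
    return [max(range(len(keys)), key=lambda i: cosine(f, keys[i])) for f in feats]

def momentum_key_update(keys: List[Vector], feats: List[Vector],
                        beta: float = 0.9) -> List[Vector]:
    """Hypothetical rule: new_key = beta * key + (1 - beta) * mean(assigned features).

    Keys that receive no features this round are kept unchanged.
    """
    groups = assign(feats, keys)
    new_keys = []
    for i, k in enumerate(keys):
        assigned = [f for f, g in zip(feats, groups) if g == i]
        if not assigned:
            new_keys.append(list(k))
            continue
        mean = [sum(col) / len(assigned) for col in zip(*assigned)]
        new_keys.append([beta * a + (1 - beta) * m for a, m in zip(k, mean)])
    return new_keys

keys = [[1.0, 0.0], [0.0, 1.0]]
feats = [[0.8, 0.2], [0.6, 0.0], [0.1, 0.9]]
print(assign(feats, keys))  # → [0, 0, 1]
new_keys = momentum_key_update(keys, feats, beta=0.9)
print([[round(v, 3) for v in k] for k in new_keys])  # → [[0.97, 0.01], [0.01, 0.99]]
```

A larger `beta` makes the keys, and therefore the selection counts, change more slowly between rounds, trading adaptation speed for stability.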

Theorems & Definitions (8)

  • Theorem 4.1: Gap between the global and local performance
  • Lemma C.1: Split the distribution risk
  • Proof
  • Lemma C.2: Bound on the generalization error
  • Proof
  • Proof
  • Theorem C.3
  • Proof