Prompt Estimation from Prototypes for Federated Prompt Tuning of Vision Transformers

M Yashwanth, Sharannya Ghosh, Aditay Tripathi, Anirban Chakraborty

TL;DR

This work tackles data heterogeneity in federated learning for Vision Transformers by proposing PEP-FedPT, which combines globally shared prompts with class-specific prompts to form a per-input Class-Contextualized Mixed Prompt (CCMP). CCMP uses per-class prototypes and client class priors to compute soft mixing weights, yielding per-sample personalization without client-specific trainable parameters. The authors prove that CCMP minimizes a quadratic upper bound on the loss and is MMSE-optimal for prompt estimation; the method also remains efficient in computation and communication. Empirical results on CIFAR-100, TinyImageNet, DomainNet, and iNaturalist show consistent improvements over state-of-the-art baselines under both label and feature heterogeneity, along with strong generalization to held-out clients. Overall, PEP-FedPT offers a practical, scalable solution for generalizable federated prompt tuning of Vision Transformers.
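The per-input mixing step can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the authors' implementation: it assumes each class weight is a Bayes-style posterior, proportional to the client's class prior times $\exp(-\|z - \mu_c\|^2/\tau)$, where $z$ is the input's CLS token at the layer where the mixed prompt is inserted and $\mu_c$ is the global prototype of class $c$. The helper names `mixing_weights` and `mixed_prompt` and the `temperature` parameter $\tau$ are hypothetical.

```python
import math
from typing import Dict, List

Vector = List[float]
Prompt = List[Vector]  # prompt tokens x embedding dim


def squared_distance(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def mixing_weights(cls_token: Vector, prototypes: Dict[int, Vector],
                   class_priors: Dict[int, float],
                   temperature: float = 1.0) -> Dict[int, float]:
    """Hypothetical soft weights: w_c proportional to prior_c * exp(-||z - mu_c||^2 / T)."""
    # Log-domain scores: log prior minus scaled squared distance to each prototype.
    scores = {}
    for c, mu in prototypes.items():
        prior = class_priors.get(c, 0.0)
        if prior > 0.0:  # classes with zero prior on this client are left out (zero weight)
            scores[c] = math.log(prior) - squared_distance(cls_token, mu) / temperature
    if not scores:
        raise ValueError("client has no class with a positive prior")
    top = max(scores.values())  # subtract the max before exp for numerical stability
    unnormalized = {c: math.exp(s - top) for c, s in scores.items()}
    total = sum(unnormalized.values())
    return {c: u / total for c, u in unnormalized.items()}


def mixed_prompt(weights: Dict[int, float], class_prompts: Dict[int, Prompt]) -> Prompt:
    """Weighted sum of class prompts; every class prompt has the same tokens x dim shape."""
    first = next(iter(class_prompts.values()))
    n_tokens, dim = len(first), len(first[0])
    mixed = [[0.0] * dim for _ in range(n_tokens)]
    for c, w in weights.items():
        for t in range(n_tokens):
            for d in range(dim):
                mixed[t][d] += w * class_prompts[c][t][d]
    return mixed
```

In this sketch, a class with zero prior on a client drops out of the mixture, so the same global prototypes and class prompts produce different mixed prompts on different clients without any per-client trainable tensor.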

Abstract

Visual Prompt Tuning (VPT) of pre-trained Vision Transformers (ViTs) has proven highly effective as a parameter-efficient fine-tuning technique for adapting large models to downstream tasks with limited data. Its parameter efficiency makes it particularly suitable for Federated Learning (FL), where both communication and computation budgets are often constrained. However, global prompt tuning struggles to generalize across heterogeneous clients, while personalized tuning overfits to local data and lacks generalization. We propose PEP-FedPT (Prompt Estimation from Prototypes for Federated Prompt Tuning), a unified framework designed to achieve both generalization and personalization in federated prompt tuning of ViTs. Within this framework, we introduce the novel Class-Contextualized Mixed Prompt (CCMP), which builds on class-specific prompts maintained alongside a globally shared prompt. For each input, CCMP adaptively combines the class-specific prompts using weights derived from global class prototypes and client class priors. This approach enables per-sample prompt personalization without storing client-dependent trainable parameters. The prompts are collaboratively optimized via standard federated averaging. Comprehensive evaluations on the CIFAR-100, TinyImageNet, DomainNet, and iNaturalist datasets demonstrate that PEP-FedPT consistently surpasses state-of-the-art baselines under diverse data heterogeneity scenarios, establishing a strong foundation for efficient and generalizable federated prompt tuning of Vision Transformers.
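The federated averaging step can be sketched as a sample-weighted mean of the prompts returned by clients, as in standard FedAvg. This sketch assumes each prompt arrives as a tokens-by-dimension list of floats and that client weights are proportional to local sample counts; applying it separately to the shared prompt and to each class prompt is an assumption, and `fedavg` is a hypothetical helper name.

```python
from typing import List, Sequence

Prompt = List[List[float]]  # prompt tokens x embedding dim


def fedavg(client_prompts: Sequence[Prompt], num_samples: Sequence[int]) -> Prompt:
    """Sample-weighted average of same-shaped client prompts (standard FedAvg rule)."""
    if not client_prompts or len(client_prompts) != len(num_samples):
        raise ValueError("need at least one client and one sample count per client")
    total = float(sum(num_samples))
    n_tokens, dim = len(client_prompts[0]), len(client_prompts[0][0])
    averaged = [[0.0] * dim for _ in range(n_tokens)]
    for prompt, n_k in zip(client_prompts, num_samples):
        weight = n_k / total  # client k's share of all training samples
        for t in range(n_tokens):
            for d in range(dim):
                averaged[t][d] += weight * prompt[t][d]
    return averaged
```

For example, `fedavg([[[1.0, 2.0]], [[3.0, 6.0]]], [1, 3])` weights the two clients by 0.25 and 0.75 and returns `[[2.5, 5.0]]`.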

Paper Structure

This paper contains 47 sections, 6 theorems, 51 equations, 10 figures, 13 tables, 1 algorithm.

Key Result

Proposition 1

If the above assumptions hold, we show that $f$ can be upper bounded as $f \leq \tilde{L} = \frac{1}{n} \sum_{k=1}^{n} \sum_{i=1}^{|C|} \delta^i_k \left( l_k^i(\mathbf{p}_{c_i}) + \frac{\beta_{\max}}{2} \left\| \mathbf{m}(k) - \mathbf{p}_{c_i} \right\|^2 \right) + \tilde{C}$, and that this bound is minimized at …
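For intuition, a standard calculation shows why a weighted combination of class prompts minimizes the quadratic term of $\tilde{L}$ for a fixed sample $k$. This is a sketch under assumptions, not the paper's proof: take $\beta_{\max} > 0$ and nonnegative weights $\delta^i_k$ with $\sum_i \delta^i_k > 0$. The terms $l_k^i(\mathbf{p}_{c_i})$ and $\tilde{C}$ do not depend on $\mathbf{m}(k)$, and the remaining term is a convex quadratic in $\mathbf{m}(k)$, so its stationary point is the global minimizer:

```latex
\begin{aligned}
\nabla_{\mathbf{m}(k)} \sum_{i=1}^{|C|} \delta^i_k \, \frac{\beta_{\max}}{2} \left\| \mathbf{m}(k) - \mathbf{p}_{c_i} \right\|^2
&= \beta_{\max} \sum_{i=1}^{|C|} \delta^i_k \left( \mathbf{m}(k) - \mathbf{p}_{c_i} \right) = \mathbf{0} \\
\Longrightarrow \quad \mathbf{m}^\star(k) &= \frac{\sum_{i=1}^{|C|} \delta^i_k \, \mathbf{p}_{c_i}}{\sum_{i=1}^{|C|} \delta^i_k}
\end{aligned}
```

If the $\delta^i_k$ are replaced by class posterior probabilities for the input, $\mathbf{m}^\star(k)$ becomes a posterior-weighted average of the class prompts, which has the form of a conditional-mean (MMSE) estimate.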

Figures (10)

  • Figure 1: The left panel illustrates server-client communication during federated training. In each communication round, clients insert shared prompts at the transformer input and, at an intermediate layer, class-contextualized prompts obtained by mixing class prompts using local class priors and global prototypes.
  • Figure 2: Top-5 accuracy computed from the minimum distance between the input's CLS token and the CLS prototypes. This shows that CLS representations in the middle layers already carry coarse task information.
  • Figure 3: Comparison of the convergence of different methods across communication rounds on CIFAR-100 with pathological non-IID partitioning, where each client observes only $10$ classes.
  • Figure 4: Comparison of non-IID label shift under the pathological and Dirichlet settings
  • Figure 5: Comparison of Non-IID Feature Shift
  • ...and 5 more figures
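The Top-5 measurement described for Figure 2 can be sketched as a nearest-prototype retrieval check. The caption does not specify these details, so the sketch assumes prototypes are per-class means of intermediate-layer CLS tokens and that distance is squared Euclidean; `class_prototypes` and `top_k_prototype_accuracy` are hypothetical helpers.

```python
import heapq
from collections import defaultdict
from typing import Dict, List, Sequence

Vector = List[float]


def squared_distance(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def class_prototypes(cls_tokens: Sequence[Vector], labels: Sequence[int]) -> Dict[int, Vector]:
    """Per-class mean of CLS tokens (assumed prototype definition)."""
    sums: Dict[int, Vector] = {}
    counts: Dict[int, int] = defaultdict(int)
    for z, y in zip(cls_tokens, labels):
        if y not in sums:
            sums[y] = [0.0] * len(z)
        for d, value in enumerate(z):
            sums[y][d] += value
        counts[y] += 1
    return {y: [v / counts[y] for v in s] for y, s in sums.items()}


def top_k_prototype_accuracy(cls_tokens: Sequence[Vector], labels: Sequence[int],
                             prototypes: Dict[int, Vector], k: int = 5) -> float:
    """Fraction of inputs whose true label is among the k nearest class prototypes."""
    hits = 0
    for z, y in zip(cls_tokens, labels):
        # Iterating the dict yields class ids; rank them by distance to this CLS token.
        nearest = heapq.nsmallest(k, prototypes, key=lambda c: squared_distance(z, prototypes[c]))
        if y in nearest:
            hits += 1
    return hits / len(labels)
```

With $k = 5$, a high score means the correct class is usually among the five closest prototypes even when it is not the single closest one, which is the sense in which middle-layer CLS tokens carry coarse task information.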

Theorems & Definitions (9)

  • Proposition 1
  • Proposition 2
  • Proposition 3
  • Proof
  • Proposition (MMSE)
  • Proof
  • Proposition 3
  • Proposition 4
  • Proof