Probabilistic Federated Prompt-Tuning with Non-IID and Imbalanced Data

Pei-Yau Weng, Minh Hoang, Lam M. Nguyen, My T. Thai, Tsui-Wei Weng, Trong Nghia Hoang

TL;DR

This work tackles federated learning with severely non-IID and imbalanced data by replacing costly full-model fine-tuning with a compact, probabilistic prompt-tuning strategy. By modeling local prompts as samples from a generative process anchored by global summarizing prompts, PFPT aligns diverse client contexts through a bipartite-matching-based aggregation that preserves privacy and reduces communication. The method achieves consistent gains over adapted FL baselines across multiple vision benchmarks, including long-tailed and globally skewed distributions, and demonstrates convergence and diversity in the learned prompts. Overall, PFPT offers a scalable, communication-efficient path to robust model adaptation in highly heterogeneous federated environments.

Abstract

Fine-tuning pre-trained models is a popular approach in machine learning for solving complex tasks with moderate data. However, fine-tuning the entire pre-trained model is ineffective in federated data scenarios where local data distributions are diversely skewed. To address this, we explore integrating federated learning with a more effective prompt-tuning method, optimizing for a small set of input prefixes to reprogram the pre-trained model's behavior. Our approach transforms federated learning into a distributed set modeling task, aggregating diverse sets of prompts to globally fine-tune the pre-trained model. We benchmark various baselines based on direct adaptations of existing federated model aggregation techniques and introduce a new probabilistic prompt aggregation method that substantially outperforms these baselines. Our reported results on a variety of computer vision datasets confirm that the proposed method is the most effective at combating extreme data heterogeneity in federated learning.
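To make the idea of aggregating sets of prompts concrete, the following is a minimal sketch of a matching-based aggregation step, written under several assumptions that are not taken from the paper: prompts are plain lists of floats, the matching cost is squared Euclidean distance, the assignment is found by brute force (viable only for tiny sets), and matched prompts are simply averaged. The function names sq_dist, match_prompts, and aggregate are hypothetical. The actual PFPT method is probabilistic and may also grow or shrink the global prompt pool, which this sketch does not attempt to reproduce.

```python
from itertools import permutations


def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def match_prompts(local_prompts, global_prompts):
    """Min-cost one-to-one assignment of local prompts to global prompts.

    Assumption: len(local_prompts) <= len(global_prompts).
    Brute force over all injective assignments, so only usable for tiny sets.
    Returns a list `assign` where assign[i] is the global index for local prompt i.
    """
    n_local = len(local_prompts)
    best_cost, best_assign = float("inf"), None
    for perm in permutations(range(len(global_prompts)), n_local):
        cost = sum(
            sq_dist(local_prompts[i], global_prompts[j]) for i, j in enumerate(perm)
        )
        if cost < best_cost:
            best_cost, best_assign = cost, list(perm)
    return best_assign


def aggregate(client_prompt_sets, global_prompts):
    """Average each global prompt with the local prompts matched to it.

    Assumption: plain averaging stands in for the paper's probabilistic update.
    """
    dim = len(global_prompts[0])
    sums = [list(g) for g in global_prompts]  # start from the current global prompts
    counts = [1] * len(global_prompts)
    for local_prompts in client_prompt_sets:
        assign = match_prompts(local_prompts, global_prompts)
        for i, j in enumerate(assign):
            for d in range(dim):
                sums[j][d] += local_prompts[i][d]
            counts[j] += 1
    return [[s / c for s in row] for row, c in zip(sums, counts)]


# Toy example: client_a stores its prompts in a different order than the server.
global_prompts = [[0.0, 0.0], [10.0, 10.0]]
client_a = [[9.0, 11.0], [1.0, -1.0]]
client_b = [[11.0, 9.0]]
new_global = aggregate([client_a, client_b], global_prompts)
print(new_global)  # [[0.5, -0.5], [10.0, 10.0]]
```

The matching step is what distinguishes this from index-wise averaging: client_a's first prompt is averaged into the second global prompt because that is where it is closest, regardless of its position in the client's list.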

Paper Structure

This paper contains 22 sections, 3 theorems, 24 equations, 5 figures, 9 tables, 2 algorithms.

Key Result

Lemma 3.1

For any scalar function $g(\mathbf{r})$ and a binary vector $\boldsymbol{\zeta} = [\zeta_1, \zeta_2, \ldots, \zeta_n]$ such that $\zeta_i \in \{0,1\}$ and $\boldsymbol{\zeta}$ has at most one non-zero component, we have with respect to any set $\{\mathbf{r}_i\}_{i=1}^n$ of valid inputs to $g(\mathbf{r})$.

Figures (5)

  • Figure 1: Test accuracy ($\%$) achieved on the CIFAR-10 dataset by solving Eq. \ref{eq:FL_prompt} via centralizing data, using FedAvg, and using FedProx on (orange) full-model (FM) and (blue) prompt-tuning (PT) setups. The evaluation is performed under (left) a standard (non-extreme) heterogeneous data partition; and (right) an extremely imbalanced data partitioning scheme (see Section \ref{sec:exp}).
  • Figure 2: Workflow of Probabilistic Federated Prompt Aggregation: (left) each client selects a subset of prompts from the global set of summarizing prompts using the prompt-selection mechanism adapted from \cite{wang2022learning}, and fine-tunes them using local data; and (right) the server collects all local prompt sets and updates the global summarizing prompts by aggregating similar local prompts. This is achieved by our proposed probabilistic federated prompt aggregation (PFPT) algorithm.
  • Figure 3: t-SNE plots of the (learned) summarizing prompts of PFPT on CIFAR-100 over $120$ communication iterations with different heterogeneity settings. Yellow triangles denote the centroids of the t-SNE embeddings of the prompts. The dashed red line visualizes their trajectories.
  • Figure 4: Variations in CIFAR-100 global prompt pool size across $120$ communication rounds under different heterogeneity settings.
  • Figure 5: t-SNE plots of the (learned) summarizing prompts of (top) our method PFPT and (bottom) GMM-PT over $120$ communication iterations with different heterogeneity settings. Yellow triangles denote the centroids of the t-SNE embeddings of the prompts. The dashed red lines visualize the trajectories of the centroids. The figure is best viewed in color.
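The client-side selection step described in the Figure 2 caption can be illustrated with a small sketch. It assumes, without confirmation from the text above, that each global prompt carries a key vector, that a query vector is computed from the client's input, and that the k prompts with the highest cosine similarity to the query are selected. The names cosine and select_prompts are hypothetical, and the mechanism actually adapted in the paper may differ in its details.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def select_prompts(query, prompt_keys, k):
    """Return the indices of the k keys most similar to the query, best first."""
    ranked = sorted(
        range(len(prompt_keys)),
        key=lambda j: cosine(query, prompt_keys[j]),
        reverse=True,
    )
    return ranked[:k]


# Toy example: four prompt keys in 2-D and one query vector.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
query = [2.0, 1.0]
selected = select_prompts(query, keys, k=2)
print(selected)  # [2, 0]
```

Only the selected prompts would then be fine-tuned on local data and sent back to the server, which is what keeps the per-round communication small relative to full-model updates.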

Theorems & Definitions (3)

  • Lemma 3.1
  • Lemma 3.2
  • Lemma 3.3