
GFPL: Generative Federated Prototype Learning for Resource-Constrained and Data-Imbalanced Vision Task

Shiwei Lu, Yuhang He, Jiashuo Li, Qiang Wang, Yihong Gong

TL;DR

A novel Generative Federated Prototype Learning (GFPL) framework that improves model accuracy by 3.6% under imbalanced data settings while maintaining low communication cost, using a dual-classifier architecture optimized via a hybrid loss that combines Dot Regression and Cross-Entropy.
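The hybrid objective can be illustrated with a small sketch. It assumes a dual classifier made of a Dot Regression (DR) head, whose class vectors are compared with the normalized feature, and a separate bias-free linear head trained with Cross-Entropy (CE). The DR form 0.5·(⟨w_y, h⟩ − 1)² on unit-norm vectors is borrowed from the neural-collapse literature, and the weight `lam` and every function name here are hypothetical; GFPL's exact formulation may differ.

```python
import math
from typing import List

def cross_entropy(logits: List[float], label: int) -> float:
    """Softmax cross-entropy for one sample, computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[label]

def _normalize(v: List[float]) -> List[float]:
    """Scale a vector to unit L2 norm (guarding against a zero vector)."""
    norm = max(math.sqrt(sum(x * x for x in v)), 1e-12)
    return [x / norm for x in v]

def dot_regression(feature: List[float], dr_weights: List[List[float]], label: int) -> float:
    """ASSUMED Dot Regression term: 0.5 * (<w_y, h> - 1)^2 on unit-norm vectors."""
    h = _normalize(feature)
    w = _normalize(dr_weights[label])
    dot = sum(a * b for a, b in zip(w, h))
    return 0.5 * (dot - 1.0) ** 2

def linear_logits(feature: List[float], ce_head: List[List[float]]) -> List[float]:
    """Logits of a bias-free linear classifier: one row of weights per class."""
    return [sum(w * x for w, x in zip(row, feature)) for row in ce_head]

def hybrid_loss(feature, dr_weights, ce_head, label, lam=1.0):
    """HYPOTHETICAL dual-classifier objective: L = L_CE + lam * L_DR."""
    ce = cross_entropy(linear_logits(feature, ce_head), label)
    return ce + lam * dot_regression(feature, dr_weights, label)
```

When a feature already points along its own DR class vector, the DR term is zero and only the CE head contributes to the loss.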

Abstract

Federated learning (FL) facilitates the secure utilization of decentralized images, advancing applications in medical image recognition and autonomous driving. However, conventional FL faces two critical challenges in real-world deployment: ineffective knowledge fusion caused by model updates biased toward majority-class features, and prohibitive communication overhead due to frequent transmissions of high-dimensional model parameters. Inspired by the human brain's efficiency in knowledge integration, we propose a novel Generative Federated Prototype Learning (GFPL) framework to address these issues. Within this framework, a prototype generation method based on a Gaussian Mixture Model (GMM) captures the statistical information of class-wise features, while a prototype aggregation strategy using the Bhattacharyya distance fuses semantically similar knowledge across clients. In addition, these fused prototypes are leveraged to generate pseudo-features, thereby mitigating feature distribution imbalance across clients. To further enhance feature alignment during local training, we devise a dual-classifier architecture, optimized via a hybrid loss combining Dot Regression and Cross-Entropy. Extensive experiments on benchmark datasets show that GFPL improves model accuracy by 3.6% under imbalanced data settings while maintaining low communication cost.
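To make the aggregation idea concrete, the sketch below assumes each prototype is a single diagonal-covariance GMM component. It computes the closed-form Bhattacharyya distance between two Gaussians and greedily fuses components whose distance falls below a threshold, merging them by moment matching. The threshold `tau`, the greedy ordering, and the moment-matching merge are illustrative assumptions, not GFPL's published rule.

```python
import math
from dataclasses import dataclass
from typing import List

@dataclass
class Gaussian:
    """One GMM component used as a class prototype (diagonal covariance assumed)."""
    weight: float
    mean: List[float]
    var: List[float]

def bhattacharyya(p: Gaussian, q: Gaussian) -> float:
    """Closed-form Bhattacharyya distance between two diagonal Gaussians."""
    dist = 0.0
    for mp, mq, vp, vq in zip(p.mean, q.mean, p.var, q.var):
        v = 0.5 * (vp + vq)  # per-dimension averaged variance
        dist += (mp - mq) ** 2 / (8.0 * v)             # mean-separation term
        dist += 0.5 * math.log(v / math.sqrt(vp * vq))  # covariance-mismatch term
    return dist

def merge(p: Gaussian, q: Gaussian) -> Gaussian:
    """Moment-matching merge of two weighted Gaussians into one."""
    w = p.weight + q.weight
    a, b = p.weight / w, q.weight / w
    mean = [a * mp + b * mq for mp, mq in zip(p.mean, q.mean)]
    var = [a * (vp + mp ** 2) + b * (vq + mq ** 2) - m ** 2
           for mp, mq, vp, vq, m in zip(p.mean, q.mean, p.var, q.var, mean)]
    return Gaussian(w, mean, var)

def aggregate(components: List[Gaussian], tau: float) -> List[Gaussian]:
    """HYPOTHETICAL greedy fusion: merge each component into its nearest
    already-fused prototype if their Bhattacharyya distance is below tau."""
    fused: List[Gaussian] = []
    for comp in components:
        best = min(range(len(fused)),
                   key=lambda i: bhattacharyya(fused[i], comp), default=None)
        if best is not None and bhattacharyya(fused[best], comp) < tau:
            fused[best] = merge(fused[best], comp)
        else:
            fused.append(comp)
    return fused
```

Because the distance accounts for both means and variances, two prototypes with similar centers but very different spreads are not automatically fused.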

Paper Structure (31 sections, 1 theorem, 28 equations, 7 figures, 3 tables, 2 algorithms)

Key Result

Theorem A.1

Under Assumptions 1-4, if the learning rate $\eta$ satisfies $\eta \leq \frac{1}{L}$, then after $T$ communication rounds, the sequence of parameters $\{\Theta^t\}$ generated by the GFPL algorithm satisfies: where $C_1$ and $C_2$ are positive constants.

Figures (7)

  • Figure 1: Generative Federated Prototype Learning (GFPL): (a) the client trains the feature extractor and projection layers on local data, generates local prototypes with a GMM, and uploads them to the server for prototype interaction; (b) after receiving the global prototypes, the client replaces its local prototypes with them, generates balanced pseudo-features through GMM sampling, and uses them to retrain the projection layer.
  • Figure 2: Visualization of features and projections under different loss functions: (a) only $\mathcal{L}_{DR}$ (features); (b) $\mathcal{L}_{DR}+\mathcal{L}_{CE}$ (features); (c) only $\mathcal{L}_{DR}$ (projections); (d) $\mathcal{L}_{DR}+\mathcal{L}_{CE}$ (projections)
  • Figure 3: The influence of hyperparameters on model performance: (a) $\lambda$; (b) number of GMM components; (c) average number of shots.
  • Figure 4: The influence of hyperparameters on model performance: (a) $\hat{w}$; (b) retraining interval; (c) initial round of prototype interaction.
  • Figure 5: Error bars of test accuracy on different datasets: (a-c) MNIST dataset with different $\bar{w}$ values; (d-f) FEMNIST dataset with different $\bar{w}$ values; (g-i) CIFAR10 dataset with different $\bar{w}$ values.
  • ...and 2 more figures
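The balanced pseudo-feature step described for Figure 1(b) can be sketched as follows. It assumes, hypothetically, that each class's global prototype is a list of diagonal-Gaussian components `(weight, mean, var)`; components are picked in proportion to their mixture weights, and every class receives the same number of samples, so minority classes are no longer under-represented when the projection layer is retrained.

```python
import random
from typing import Dict, List, Tuple

# HYPOTHETICAL prototype layout: (mixture weight, mean vector, per-dimension variance).
Component = Tuple[float, List[float], List[float]]

def sample_balanced(prototypes: Dict[int, List[Component]], per_class: int,
                    seed: int = 0) -> List[Tuple[List[float], int]]:
    """Draw the same number of pseudo-features for every class from its GMM."""
    rng = random.Random(seed)
    samples: List[Tuple[List[float], int]] = []
    for label, comps in sorted(prototypes.items()):
        weights = [w for w, _, _ in comps]
        # Choose one mixture component per sample, proportional to its weight.
        picks = rng.choices(comps, weights=weights, k=per_class)
        for _, mean, var in picks:
            # Sample each dimension independently (diagonal covariance assumed).
            feat = [rng.gauss(m, v ** 0.5) for m, v in zip(mean, var)]
            samples.append((feat, label))
    rng.shuffle(samples)  # mix classes before they are used for retraining
    return samples
```

Fixing `per_class` is what balances the retraining set; the mixture weights only shape the feature distribution within each class.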

Theorems & Definitions (2)

  • Theorem A.1
  • Proof