Taming Cross-Domain Representation Variance in Federated Prototype Learning with Heterogeneous Data Domains

Lei Wang, Jieming Bian, Letian Zhang, Chen Chen, Jie Xu

TL;DR

This work introduces FedPLVM, which establishes variance-aware dual-level prototypes clustering and employs a novel $\alpha$-sparsity prototype loss that aligns samples from underrepresented domains, enhancing intra-class similarity and reducing inter-class similarity.

Abstract

Federated learning (FL) allows collaborative machine learning training without sharing private data. While most FL methods assume identical data domains across clients, real-world scenarios often involve heterogeneous data domains. Federated Prototype Learning (FedPL) addresses this issue, using mean feature vectors as prototypes to enhance model generalization. However, existing FedPL methods create the same number of prototypes for each client, leading to cross-domain performance gaps and disparities for clients with varied data distributions. To mitigate cross-domain feature representation variance, we introduce FedPLVM, which establishes variance-aware dual-level prototypes clustering and employs a novel $\alpha$-sparsity prototype loss. The dual-level prototypes clustering strategy creates local clustered prototypes based on private data features, then performs global prototypes clustering to reduce communication complexity and preserve local data privacy. The $\alpha$-sparsity prototype loss aligns samples from underrepresented domains, enhancing intra-class similarity and reducing inter-class similarity. Evaluations on Digit-5, Office-10, and DomainNet datasets demonstrate our method's superiority over existing approaches.
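
To make the baseline concrete, the following is a minimal sketch of the "mean feature vectors as prototypes" idea that the abstract attributes to existing FedPL methods: one averaged prototype per class on each client, then a plain average across clients on the server. The function names `mean_prototypes` and `average_global_prototypes` and the use of NumPy are assumptions made for illustration; they are not taken from the paper's code.

```python
import numpy as np


def mean_prototypes(features, labels):
    """Return {class_id: mean feature vector} for one client's embeddings.

    Hypothetical helper: one prototype per class, regardless of how spread
    out that class's features are on this client.
    """
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    return {int(c): features[labels == c].mean(axis=0) for c in np.unique(labels)}


def average_global_prototypes(client_prototypes):
    """Average each class's local prototypes across clients (vanilla aggregation)."""
    classes = sorted({c for protos in client_prototypes for c in protos})
    return {
        c: np.mean([p[c] for p in client_prototypes if c in p], axis=0)
        for c in classes
    }


# Two toy clients with 2-D embeddings (illustrative values only).
client_a = mean_prototypes([[0.0, 0.0], [2.0, 2.0], [10.0, 10.0]], [0, 0, 1])
client_b = mean_prototypes([[3.0, 3.0], [12.0, 12.0]], [0, 1])
global_protos = average_global_prototypes([client_a, client_b])
```

Because each class collapses to a single mean, a client whose features for a class are widely spread is summarized exactly as coarsely as a client whose features are tightly grouped, which is the limitation the abstract points to.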

Paper Structure (25 sections, 10 equations, 10 figures, 12 tables, 1 algorithm)

Figures (10)

  • Figure 1: Illustration of federated learning with heterogeneous data domains. The Vanilla column depicts the local feature distribution of the standard FedPL approach, which obtains averaged local and global prototypes directly. The proposed method, shown in the adjacent column, yields a larger inter-class distance and a smaller intra-class distance. Note that without capturing variance information, even for hard domains, the locally averaged prototypes of each class can be well separated while the feature vectors themselves remain mixed. Both methods show noticeable variations in domain characteristics across datasets, as detailed in Sec. \ref{sec:proposed}.
  • Figure 2: An overview of our proposed FedPLVM framework. Once the sample embedding is generated by the feature extractor, the client conducts the first-level local clustering, following Eq. \ref{eq:lcp}. Subsequently, the server gathers all local clustered prototypes and local models (comprising feature extractors and classifiers), initiates the second-level global clustering based on Eq. \ref{eq:gpc}, and averages the local models to form a global model. Finally, clients use the received global clustered prototypes to update the local model, employing the loss functions $\mathcal{L}_\alpha$ from Eq. \ref{eq:alpha} and $\mathcal{L}_{CE}$ from Eq. \ref{eq:ce}.
  • Figure 3: Visualization of different prototype generation methods. The first row averages feature vectors locally and averages local prototypes globally. The second row averages feature vectors locally and clusters local prototypes globally. The last row (ours) clusters feature vectors locally and clusters local clustered prototypes globally. The last column, Total, visualizes the mixture of feature vectors from all datasets. Details in Sec. \ref{sec:cluster}.
  • Figure 4: Impact of the sparsity parameter $\alpha$ and the prototype loss weight $\lambda$. The left figure shows the accuracy on two selected datasets and the average accuracy among all clients for different values of $\alpha$. The right figure shows the effect of different values of $\lambda$ for both FPL and our proposed approach. Details in Sec. \ref{sec:loss}.
  • Figure 5: Comparison of the components of the $\alpha$-sparsity prototype loss. Contrast and Correction stand for the contrastive and corrective loss terms in the total $\alpha$-sparsity loss, respectively. Avg is the average accuracy over all clients. Details in Sec. \ref{sec:loss}.
  • ...and 5 more figures
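
The two clustering levels described in the Figure 2 and Figure 3 captions can be sketched as follows. This is only an illustration under stated assumptions: the small `kmeans` routine is a stand-in for whatever clustering the paper defines in its local and global clustering equations, and the cluster counts `k_local` and `k_global` are hypothetical parameters fixed by hand here, whereas a variance-aware scheme would let the number of prototypes differ across clients. All function names are invented for this sketch.

```python
import numpy as np


def kmeans(points, k, iters=20):
    """Tiny k-means stand-in (assumption, not the paper's clustering)."""
    points = np.asarray(points, dtype=float)
    k = min(k, len(points))
    # Deterministic farthest-point initialisation.
    centers = [points[0]]
    for _ in range(1, k):
        dists = np.min(
            [np.linalg.norm(points - c, axis=1) for c in centers], axis=0
        )
        centers.append(points[np.argmax(dists)])
    centers = np.array(centers)
    for _ in range(iters):
        # Assign each point to its nearest centre, then recompute the centres.
        pairwise = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=2)
        assign = np.argmin(pairwise, axis=1)
        centers = np.array([
            points[assign == j].mean(axis=0) if np.any(assign == j) else centers[j]
            for j in range(k)
        ])
    return centers


def local_clustered_prototypes(features, labels, k_local):
    """First level: cluster one client's features separately for each class."""
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    return {
        int(c): kmeans(features[labels == c], k_local) for c in np.unique(labels)
    }


def global_clustered_prototypes(client_protos, k_global):
    """Second level: the server clusters all clients' local prototypes per class."""
    classes = sorted({c for protos in client_protos for c in protos})
    return {
        c: kmeans(np.vstack([p[c] for p in client_protos if c in p]), k_global)
        for c in classes
    }


# Client A's class-0 features form two groups; client B's form one (toy values).
protos_a = local_clustered_prototypes(
    [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]], [0, 0, 0, 0], k_local=2
)
protos_b = local_clustered_prototypes([[0.2, 0.4], [0.2, 0.6]], [0, 0], k_local=1)
global_protos = global_clustered_prototypes([protos_a, protos_b], k_global=2)
```

In this toy run, client A contributes two class-0 prototypes and client B contributes one, and the server reduces the three local prototypes to two global ones. Only prototypes, not raw features, are passed to the server-side step, which mirrors the privacy and communication motivation given in the abstract.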