FLUX: Efficient Descriptor-Driven Clustered Federated Learning under Arbitrary Distribution Shifts

Dario Fenoglio, Mohan Li, Pietro Barbiero, Nicholas D. Lane, Marc Langheinrich, Martin Gjoreski

TL;DR

Flux introduces a descriptor-driven clustered federated learning framework that simultaneously handles four common distribution shifts (P(X), P(Y), P(Y|X), P(X|Y)) without requiring prior knowledge of the number of clusters, and enables test-time adaptation for unseen, unlabeled clients. It learns compact, privacy-preserving descriptors in the latent space, clusters clients with an unsupervised method, and trains cluster-specific models, achieving state-of-the-art robustness on six datasets against ten baselines. The approach keeps overhead close to that of FedAvg, scales to large client populations, and supports optional differential privacy on the descriptors for formal privacy guarantees. Overall, Flux provides a practical, scalable solution for robust FL under realistic non-IID conditions and test-time deployment scenarios.
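To make the descriptor step concrete, here is a minimal sketch in Python/NumPy. The paper's actual descriptor is not specified in this summary, so we assume, purely for illustration, that a client summarises its data by the mean of its latent embeddings, and that the optional differential privacy is the classical Gaussian mechanism applied to that mean after L2 clipping. `client_descriptor`, `clip`, `epsilon` and `delta` are hypothetical names, not FLUX's API.

```python
import math

import numpy as np


def client_descriptor(z, clip=1.0, epsilon=None, delta=1e-5, rng=None):
    """Hypothetical descriptor: mean of L2-clipped latent embeddings.

    z has shape (n_samples, latent_dim). If `epsilon` is given, Gaussian noise
    calibrated by the classical Gaussian mechanism is added, which gives
    (epsilon, delta)-DP for 0 < epsilon < 1 when n_samples is public.
    """
    z = np.asarray(z, dtype=float)
    norms = np.linalg.norm(z, axis=1, keepdims=True)
    # Scale each embedding so that its L2 norm is at most `clip`.
    z = z * np.minimum(1.0, clip / np.maximum(norms, 1e-12))
    mean = z.mean(axis=0)
    if epsilon is None:
        return mean
    if not 0.0 < epsilon < 1.0:
        raise ValueError("the classical Gaussian mechanism needs 0 < epsilon < 1")
    # Replacing one clipped embedding moves the mean by at most 2 * clip / n.
    sensitivity = 2.0 * clip / z.shape[0]
    sigma = sensitivity * math.sqrt(2.0 * math.log(1.25 / delta)) / epsilon
    rng = np.random.default_rng() if rng is None else rng
    return mean + rng.normal(0.0, sigma, size=mean.shape)
```

For example, `client_descriptor(z, clip=1.0, epsilon=0.5, rng=np.random.default_rng(0))` returns a noisy vector of the latent dimension that a client could share instead of its data. Clipping bounds each sample's influence, which is what makes the noise scale finite; the price is a biased mean when many embeddings exceed `clip`.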

Abstract

Federated Learning (FL) enables collaborative model training across multiple clients while preserving data privacy. Traditional FL methods often use a global model to fit all clients, assuming that clients' data are independent and identically distributed (IID). However, when this assumption does not hold, the global model accuracy may drop significantly, limiting FL applicability in real-world scenarios. To address this gap, we propose FLUX, a novel clustering-based FL (CFL) framework that addresses the four most common types of distribution shifts during both training and test time. To this end, FLUX leverages privacy-preserving client-side descriptor extraction and unsupervised clustering to ensure robust performance and scalability across varying levels and types of distribution shifts. Unlike existing CFL methods addressing non-IID client distribution shifts, FLUX i) does not require any prior knowledge of the types of distribution shifts or the number of client clusters, and ii) supports test-time adaptation, enabling unseen and unlabeled clients to benefit from the most suitable cluster-specific models. Extensive experiments across four standard benchmarks, two real-world datasets and ten state-of-the-art baselines show that FLUX improves performance and stability under diverse distribution shifts, achieving an average accuracy gain of up to 23 percentage points over the best-performing baselines, while maintaining computational and communication overhead comparable to FedAvg.
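The abstract's two distinguishing properties, clustering without a preset number of clusters and routing unseen, unlabeled clients to a cluster-specific model, can be illustrated with the sketch below. It is not FLUX's algorithm: as a stand-in, it links clients whose descriptors lie within a hypothetical `threshold` of each other (single-linkage grouping, so the number of clusters emerges from the data) and assigns a new client to the cluster whose centroid is nearest to its descriptor. If the descriptor needs no labels, as assumed here, the test-time step also works for unlabeled clients. All function names are hypothetical.

```python
import numpy as np


def cluster_descriptors(descriptors, threshold):
    """Single-linkage grouping: clients closer than `threshold` share a cluster.

    The number of clusters is not fixed in advance; it is the number of
    connected components of the graph that links descriptors whose Euclidean
    distance is at most `threshold`.
    """
    d = np.asarray(descriptors, dtype=float)
    n = len(d)
    parent = list(range(n))  # union-find forest over clients

    def find(i):
        # Follow parent pointers to the root, halving the path as we go.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    dists = np.linalg.norm(d[:, None, :] - d[None, :, :], axis=-1)
    for i in range(n):
        for j in range(i + 1, n):
            if dists[i, j] <= threshold:
                parent[find(i)] = find(j)

    # Relabel component roots as 0, 1, ..., k-1 in order of first appearance.
    ids = {}
    return np.array([ids.setdefault(find(i), len(ids)) for i in range(n)])


def cluster_centroids(descriptors, labels):
    """Mean descriptor of each cluster (row k belongs to cluster k)."""
    d = np.asarray(descriptors, dtype=float)
    return np.stack([d[labels == k].mean(axis=0) for k in range(labels.max() + 1)])


def assign_unseen_client(descriptor, centroids):
    """Test-time step: index of the cluster whose centroid is nearest."""
    diffs = centroids - np.asarray(descriptor, dtype=float)
    return int(np.argmin(np.linalg.norm(diffs, axis=1)))
```

With descriptors `[[0, 0], [0.1, 0], [5, 5], [5.1, 5]]` and `threshold=1.0`, the first two clients form cluster 0 and the last two form cluster 1; an unseen client with descriptor `[4.8, 5.2]` is then routed to cluster 1's model. Here the threshold, rather than the number of clusters, is the parameter to tune.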

Paper Structure

This paper contains 63 sections, 1 theorem, 22 equations, 12 figures, 66 tables, and 2 algorithms.

Key Result

Proposition C.1

Let $d(i,j)$ denote the distance between the descriptors of clients $i$ and $j$. Under assumptions (A1)–(A3), there exist constants $0 < c \le C < \infty$ such that

$$c\,d(i,j) \;\le\; W_2^2(P_i, P_j) \;\le\; C\,d(i,j),$$

where $P_i$ and $P_j$ are the two clients' marginal distributions. Hence, within the admissible covariance set, the squared 2-Wasserstein distance and the descriptor distance are equivalent up to constant factors.
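Proposition C.1 can be checked numerically in a simple special case. The sketch below assumes, as an illustration only, that each client's marginal is Gaussian $N(m, \Sigma)$ and that its descriptor stores the mean and the matrix square root $\Sigma^{1/2}$, so the descriptor distance is $\|m_1 - m_2\|^2 + \|\Sigma_1^{1/2} - \Sigma_2^{1/2}\|_F^2$. The closed-form Gaussian distance is $W_2^2 = \|m_1 - m_2\|^2 + \operatorname{tr}\big(\Sigma_1 + \Sigma_2 - 2(\Sigma_2^{1/2}\Sigma_1\Sigma_2^{1/2})^{1/2}\big)$. For this assumed descriptor, $W_2^2 \le d$ always holds (so $C = 1$ suffices for the upper bound), with equality when the covariances commute, e.g. when both are diagonal; the matching lower bound needs conditions such as (A1)–(A3), which this sketch does not reproduce. The function names are hypothetical.

```python
import numpy as np


def spd_sqrt(s):
    """Principal square root of a symmetric positive semi-definite matrix."""
    w, v = np.linalg.eigh((s + s.T) / 2.0)  # symmetrise against round-off
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def w2_squared_gaussian(m1, s1, m2, s2):
    """Closed-form squared 2-Wasserstein distance between N(m1, s1) and N(m2, s2)."""
    r2 = spd_sqrt(s2)
    cross = spd_sqrt(r2 @ s1 @ r2)
    return float(np.sum((m1 - m2) ** 2) + np.trace(s1 + s2 - 2.0 * cross))


def descriptor_distance(m1, s1, m2, s2):
    """Hypothetical descriptor distance on (mean, covariance square root) pairs."""
    root_gap = spd_sqrt(s1) - spd_sqrt(s2)
    return float(np.sum((m1 - m2) ** 2) + np.sum(root_gap ** 2))
```

With $m_1 = (0, 0)$, $m_2 = (1, 0)$, $\Sigma_1 = \mathrm{diag}(1, 4)$ and $\Sigma_2 = \mathrm{diag}(4, 9)$, both functions return 3 up to floating-point error: 1 from the means and 2 from the covariances.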

Figures (12)

  • Figure 1: Types of data distribution shifts. (a) Feature distribution shift ($P(X)$): two subsets differ in feature distributions, while label distributions are similar. (b) Label distribution shift ($P(Y)$): two subsets differ in label distributions, while feature distributions (for each class) are similar. (c) $P(Y|X)$ concept shift: two subsets share the same feature distributions but assign labels to those features differently. (d) $P(X|Y)$ concept shift: two subsets share the same label distributions but differ in the feature distribution of each class.
  • Figure 2: Flux's probabilistic graphical model (PGM). Solid lines: data-generating mechanism. Dashed lines: inference direction. Gray lines: present only during training.
  • Figure 3: Overview of the Flux framework for efficient unsupervised CFL. Flux operates without prior knowledge of client data, handling distribution shifts and optimizing cluster-specific model assignment to unseen, unlabeled clients at inference.
  • Figure 4: Mean accuracy and standard deviation across heterogeneity levels for MNIST.
  • Figure 5: Mean accuracy and standard deviation on MNIST dataset with varying numbers of clients. Left: known association condition, where test-time cluster associations are available. Middle: test phase condition, where cluster associations are inferred. Right: training time per 10 rounds.
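The four shift types in Figure 1 can be reproduced on toy data. The following sketch, an illustrative construction and not the paper's benchmark protocol, draws two balanced classes from 2-D Gaussians and splits them into two clients in four ways, each changing the distribution named in the caption while (approximately) preserving the other one. The function names, the class means and the 90/10 label skew are assumptions.

```python
import numpy as np


def make_base(n, rng):
    """Two balanced classes with X | Y=k ~ N(mu_k, I), mu_0=(0,-2), mu_1=(0,2)."""
    y = rng.integers(0, 2, size=n)
    means = np.array([[0.0, -2.0], [0.0, 2.0]])
    return means[y] + rng.normal(size=(n, 2)), y


def feature_shift(x, y):
    """(a) P(X) differs: split on x[:, 0], which is independent of the label."""
    left = x[:, 0] < 0
    return (x[left], y[left]), (x[~left], y[~left])


def label_shift(x, y, major=0.9):
    """(b) P(Y) differs: A is mostly class 0, B mostly class 1; P(X|Y) unchanged."""
    idx0, idx1 = np.flatnonzero(y == 0), np.flatnonzero(y == 1)
    k0, k1 = int(major * len(idx0)), int((1 - major) * len(idx1))
    a = np.concatenate([idx0[:k0], idx1[:k1]])
    b = np.concatenate([idx0[k0:], idx1[k1:]])
    return (x[a], y[a]), (x[b], y[b])


def concept_shift_y_given_x(x, y):
    """(c) P(Y|X) differs: same features, but client B's labels are flipped."""
    half = len(y) // 2
    return (x[:half], y[:half]), (x[half:], 1 - y[half:])


def concept_shift_x_given_y(x, y):
    """(d) P(X|Y) differs: same labels, but client B's features are rotated 90 degrees."""
    half = len(y) // 2
    rot = np.array([[0.0, -1.0], [1.0, 0.0]])
    return (x[:half], y[:half]), (x[half:] @ rot.T, y[half:])
```

In (d) the rotation moves the class means from $(0, \pm 2)$ to $(\mp 2, 0)$, so a classifier trained on client A's data would misread client B's samples even though both clients see the same label proportions.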

Theorems & Definitions (2)

  • Proposition C.1: Lipschitz-equivalence to $W_2$ for marginals
  • proof