Federated Learning over Connected Modes

Dennis Grinwald, Philipp Wiesner, Shinichi Nakajima

TL;DR

This work tackles statistical heterogeneity in federated learning by leveraging mode connectivity to learn a linearly connected low-loss solution simplex, enabling client personalization within a shared global structure. Floco assigns each client a subregion of the simplex based on its gradient signals and jointly trains the simplex endpoints to form a shared, low-loss global solution simplex, improving both local and global performance with minimal overhead. An optional Floco+ extension adds Ditto-style local fine-tuning for further personalization. On the CIFAR-10 and FEMNIST benchmarks, Floco achieves higher accuracy and better calibration, narrows the gap to the worst-performing clients, and reduces time-to-accuracy, demonstrating practical value for cross-silo FL with heterogeneous data. The approach also offers insights into gradient variance reduction and into scalable, communication-friendly personalization via a small simplex over the last layer.
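To make the simplex idea concrete, here is a minimal sketch assuming that a point in the solution simplex is the convex combination of the endpoint weights, with coefficients drawn uniformly from the probability simplex (a flat Dirichlet distribution, obtained by normalizing independent Exp(1) draws). Plain Python lists stand in for flattened last-layer weights, and `sample_uniform_simplex` and `combine_endpoints` are hypothetical helper names, not part of any released Floco code.

```python
import random
from typing import List, Sequence


def sample_uniform_simplex(num_vertices: int, rng: random.Random) -> List[float]:
    """Draw coefficients uniformly from the probability simplex (Dirichlet(1, ..., 1))."""
    draws = [rng.expovariate(1.0) for _ in range(num_vertices)]
    total = sum(draws)
    return [d / total for d in draws]


def combine_endpoints(endpoints: Sequence[Sequence[float]],
                      alpha: Sequence[float]) -> List[float]:
    """Return sum_i alpha_i * theta_i for flattened endpoint weight vectors theta_i."""
    if len(endpoints) != len(alpha):
        raise ValueError("need exactly one coefficient per endpoint")
    dim = len(endpoints[0])
    return [sum(a * theta[j] for a, theta in zip(alpha, endpoints)) for j in range(dim)]


# Three hypothetical endpoints of a tiny 2-parameter "last layer".
endpoints = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
rng = random.Random(0)
alpha = sample_uniform_simplex(len(endpoints), rng)  # a random point in the simplex
theta = combine_endpoints(endpoints, alpha)          # the model weights at that point
print(alpha, theta)
```

Under these assumptions, a training loop would sample fresh coefficients for each minibatch and backpropagate through the combination, so each endpoint receives the weight-space gradient scaled by its coefficient.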

Abstract

Statistical heterogeneity in federated learning poses two major challenges: slow global training due to conflicting gradient signals, and the need for personalization to local distributions. In this work, we tackle both challenges by leveraging recent advances in linear mode connectivity, i.e., identifying a linearly connected low-loss region in the parameter space of neural networks, which we call the solution simplex. We propose federated learning over connected modes (Floco), where clients are assigned local subregions in this simplex based on their gradient signals and together learn the shared global solution simplex. This allows the client models to be personalized to their local distributions within the degrees of freedom of the solution simplex, and it homogenizes the update signals for training the global simplex. Our experiments show that Floco accelerates the global training process and significantly improves local accuracy with minimal computational overhead in cross-silo federated learning settings.
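Assigning clients to points in the simplex requires some map from a gradient signal to simplex coordinates. The sketch below assumes each client's signal has already been reduced to one score per simplex vertex (a hypothetical preprocessing step not described here) and then applies the standard sort-based Euclidean projection onto the probability simplex. This projection is a swapped-in, well-known technique and may differ from the paper's exact procedure; `project_to_simplex` and the client scores are illustrative only.

```python
from itertools import accumulate
from typing import Dict, List, Sequence


def project_to_simplex(v: Sequence[float]) -> List[float]:
    """Euclidean projection of v onto {w : w_i >= 0, sum_i w_i = 1} (sort-based)."""
    u = sorted(v, reverse=True)         # scores in descending order
    partial_sums = list(accumulate(u))  # running sums of the sorted scores
    threshold = 0.0
    for j, (u_j, s_j) in enumerate(zip(u, partial_sums), start=1):
        candidate = (s_j - 1.0) / j
        if u_j - candidate > 0:
            # The condition holds for j = 1..k and fails afterwards,
            # so the last j that satisfies it fixes the threshold.
            threshold = candidate
    return [max(x - threshold, 0.0) for x in v]


# Hypothetical per-vertex scores for two clients (illustration only).
signals: Dict[str, List[float]] = {
    "client_1": [2.0, 0.0, 0.0],
    "client_2": [0.9, 0.6, -0.5],
}
points = {name: project_to_simplex(s) for name, s in signals.items()}
print(points)
```

Euclidean projection onto a convex set is non-expansive, so clients whose signals are close are mapped to nearby points, which is the property illustrated in Figure 1.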

Paper Structure (34 sections, 15 equations, 5 figures, 6 tables, 1 algorithm)

Figures (5)

  • Figure 1: Floco represents each client as a point ($\star$ in the top-center plot) by projecting its gradient signals onto the simplex, so that similar clients lie close to each other. In each communication round, each client uniformly samples points in the neighborhood of its projected point (top-right plot), and all clients jointly train the solution simplex. The lower row shows the resulting test loss on the solution simplex: the loss for the global distribution (left) is uniformly small, while the losses for the individual local distributions (center for client 1, right for client 2) are small around the respective projected points.
  • Figure 2: Global (left) and average local (center) test accuracy for CifarCNN on CIFAR-10, 5-Fold. For Floco, average local test accuracy clearly jumps at $\tau=250$ as a result of the subregion assignment. The right panel shows the total variance of the gradients of the last fully-connected layer.
  • Figure 3: Global test accuracy.
  • Figure 4: Average local client test accuracy.
  • Figure 6: Average local client (left) and global (right) test accuracies for different settings of the subregion assignment time step $\tau$ and the subregion radius $\rho$.
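Figures 2 and 6 refer to two hyperparameters, the subregion assignment time step $\tau$ and the subregion radius $\rho$, and Figure 1 describes sampling in a neighborhood of each client's projected point. The following is a minimal sketch, assuming that before round $\tau$ a client samples uniformly from the whole simplex and afterwards samples uniformly from the points within Euclidean distance $\rho$ of its projected point. The Euclidean neighborhood, the rejection-sampling strategy, and the names `sample_client_point`, `radius`, and `tau` are assumptions for illustration, not the paper's implementation.

```python
import math
import random
from typing import List, Sequence


def sample_uniform_simplex(num_vertices: int, rng: random.Random) -> List[float]:
    """Draw coefficients uniformly from the probability simplex (Dirichlet(1, ..., 1))."""
    draws = [rng.expovariate(1.0) for _ in range(num_vertices)]
    total = sum(draws)
    return [d / total for d in draws]


def sample_client_point(center: Sequence[float], radius: float, round_idx: int,
                        tau: int, rng: random.Random) -> List[float]:
    """Sample simplex coordinates for one local step (hypothetical schedule).

    Before round `tau`: uniform over the whole simplex.
    From round `tau` on: uniform over the part of the simplex within Euclidean
    distance `radius` of the client's projected point `center`, via rejection sampling.
    """
    if round_idx < tau:
        return sample_uniform_simplex(len(center), rng)
    if radius <= 0:
        raise ValueError("radius must be positive")
    while True:
        alpha = sample_uniform_simplex(len(center), rng)
        if math.dist(alpha, center) <= radius:  # keep only points in the neighborhood
            return alpha


rng = random.Random(42)
center = [0.65, 0.35, 0.0]  # hypothetical projected point of one client
early = sample_client_point(center, radius=0.1, round_idx=10, tau=250, rng=rng)
late = sample_client_point(center, radius=0.1, round_idx=300, tau=250, rng=rng)
print(early, late)
```

Rejection sampling keeps the samples exactly uniform over the intersection of the ball and the simplex, but its acceptance rate drops quickly as $\rho$ shrinks or the number of vertices grows, so a practical implementation would likely use a more direct sampler.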