A Fast and Flat Federated Learning Method via Weighted Momentum and Sharpness-Aware Minimization

Tianle Li, Yongzhi Huang, Linshan Jiang, Chang Liu, Qipeng Xie, Wenfeng Du, Lu Wang, Kaishun Wu

TL;DR

This paper tackles the challenge of achieving fast convergence with flat minima in non-IID federated learning by integrating momentum with Sharpness-Aware Minimization (SAM). It identifies two failure modes, local-global curvature misalignment and momentum-echo oscillation, and introduces FedWMSAM, which uses a momentum-guided global perturbation and a cosine-similarity-based adaptive coupling to realize an early-momentum, late-SAM schedule. The paper derives a non-IID convergence bound that explicitly accounts for perturbation-induced variance, highlighting robustness to data heterogeneity and system factors. Empirically, FedWMSAM performs strongly across multiple datasets and heterogeneity settings, reaching target accuracies in fewer rounds at near-FedAvg per-round cost, thereby delivering fast-and-flat optimization in realistic FL scenarios.
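The summary does not spell out the exact adaptive coupling rule, so the following is only a minimal sketch under stated assumptions. It assumes that the weight on momentum is the cosine similarity between the current local gradient and the momentum, clipped to [0, 1], and that the update direction is a convex combination of the momentum and a SAM gradient. The clipping, the linear mixing, and the function names `cosine_similarity`, `momentum_weight`, and `coupled_direction` are hypothetical choices. They are meant only to show how a similarity signal can switch between momentum-dominated and SAM-dominated updates.

```python
import math
from typing import List

Vector = List[float]


def cosine_similarity(u: Vector, v: Vector, eps: float = 1e-12) -> float:
    """Cosine of the angle between u and v; eps guards against zero norms."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv + eps)


def momentum_weight(grad: Vector, momentum: Vector) -> float:
    """Hypothetical rule: alpha = clip(cos(grad, momentum), 0, 1).

    Aligned gradient and momentum -> alpha near 1 (momentum-dominated update).
    Orthogonal or opposing directions -> alpha = 0 (SAM-dominated update).
    """
    return min(1.0, max(0.0, cosine_similarity(grad, momentum)))


def coupled_direction(grad: Vector, momentum: Vector, sam_grad: Vector) -> Vector:
    """Hypothetical convex combination: alpha * momentum + (1 - alpha) * SAM gradient."""
    alpha = momentum_weight(grad, momentum)
    return [alpha * m + (1.0 - alpha) * s for m, s in zip(momentum, sam_grad)]
```

For example, `momentum_weight([1.0, 0.0], [2.0, 0.0])` is essentially 1, so the momentum passes through almost unchanged. By contrast, `momentum_weight([0.0, 1.0], [1.0, 0.0])` is 0, and `coupled_direction` then returns the SAM gradient as-is. If gradients and momentum stay aligned early in training and decorrelate later, this mapping yields the early-momentum, late-SAM behavior described above.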

Abstract

In federated learning (FL), models must \emph{converge quickly} under tight communication budgets while \emph{generalizing} across non-IID client distributions. These twin requirements have led to two widely used techniques: client/server \emph{momentum} to accelerate progress, and \emph{sharpness-aware minimization} (SAM) to prefer flat solutions. However, simply combining momentum and SAM leaves two structural issues unresolved in non-IID FL. We identify and formalize two failure modes: \emph{local-global curvature misalignment} (local SAM directions need not reflect the global loss geometry) and \emph{momentum-echo oscillation} (late-stage instability caused by accumulated momentum). To our knowledge, these failure modes have not been jointly articulated and addressed in the FL literature. We propose \textbf{FedWMSAM} to address both. First, we construct a momentum-guided global perturbation from server-aggregated momentum to align clients' SAM directions with the global descent geometry, enabling a \emph{single-backprop} SAM approximation that preserves efficiency. Second, we couple momentum and SAM via a cosine-similarity adaptive rule, yielding an early-momentum, late-SAM two-phase training schedule. On the theory side, we provide a non-IID convergence bound that \emph{explicitly models the perturbation-induced variance} $\sigma_\rho^2=\sigma^2+(L\rho)^2$ and its dependence on $(S, K, R, N)$. Extensive experiments on multiple datasets and model architectures validate the effectiveness, adaptability, and robustness of our method, demonstrating its superiority in addressing the optimization challenges of federated learning. Our code is available at https://github.com/Huang-Yongzhi/NeurlPS_FedWMSAM.
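Standard SAM computes one gradient to find the ascent perturbation $\epsilon = \rho\, g/\|g\|$ and a second gradient at $w+\epsilon$, so each step costs two backprops. The sketch below shows, under stated assumptions, why a perturbation built from a vector the client already holds saves one of them. It assumes that the server-aggregated momentum $m$ estimates the global gradient, so that $\epsilon = \rho\, m/\|m\|$ plays the role of SAM's ascent step. This sign convention and normalization are assumptions, not taken from the paper's code. The names `sam_step`, `momentum_guided_sam_step`, and `CountingQuadratic` (a toy loss that counts gradient calls) are hypothetical.

```python
import math
from typing import Callable, List

Vector = List[float]
GradFn = Callable[[Vector], Vector]


def _norm(v: Vector) -> float:
    return math.sqrt(sum(x * x for x in v))


def sam_step(w: Vector, grad_fn: GradFn, lr: float, rho: float) -> Vector:
    """Standard SAM update: two gradient evaluations (backprops) per step."""
    g = grad_fn(w)                                   # backprop 1: local ascent direction
    scale = rho / (_norm(g) + 1e-12)
    w_adv = [wi + scale * gi for wi, gi in zip(w, g)]
    g_adv = grad_fn(w_adv)                           # backprop 2: gradient at perturbed point
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


def momentum_guided_sam_step(
    w: Vector, grad_fn: GradFn, momentum: Vector, lr: float, rho: float
) -> Vector:
    """Hypothetical single-backprop variant: the ascent direction is the
    normalized server momentum, so no extra local gradient is computed."""
    scale = rho / (_norm(momentum) + 1e-12)
    w_adv = [wi + scale * mi for wi, mi in zip(w, momentum)]
    g_adv = grad_fn(w_adv)                           # the only backprop in this step
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


class CountingQuadratic:
    """Toy loss f(w) = 0.5 * sum(a_i * w_i**2) that counts gradient calls."""

    def __init__(self, curvatures: Vector) -> None:
        self.a = curvatures
        self.calls = 0

    def grad(self, w: Vector) -> Vector:
        self.calls += 1
        return [ai * wi for ai, wi in zip(self.a, w)]
```

With `CountingQuadratic([1.0, 4.0])` and `w = [1.0, 1.0]`, `sam_step` makes two gradient calls. `momentum_guided_sam_step` with `momentum=[1.0, 4.0]` makes one call and, because that momentum is parallel to the local gradient here, produces the identical update. Under non-IID data the two differ: the momentum-guided perturbation follows the aggregated direction rather than the client's local curvature, which is the misalignment the abstract aims to correct.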

Paper Structure

This paper contains 47 sections, 8 theorems, 62 equations, 17 figures, 9 tables, 1 algorithm.

Key Result

Lemma 1

Under Assumption a1, if $\gamma L \leq \frac{1}{24}$, the following holds for all $r \geq 0$:

Figures (17)

  • Figure 1: The core of SAMs.
  • Figure 2: FedWMSAM idea: (a) personalized momentum reduces local-global discrepancy; (b) the difference between the local model and the momentum-based model guides global perturbation estimation; (c) dynamic weighting adjusts momentum vs. SAM based on gradient-momentum similarity.
  • Figure 3: Momentum vs SAM in FL.
  • Figure 4: Performance comparison on CIFAR-10 (left) and CIFAR-100 (right).
  • Figure 5: t-SNE visualization results of client embeddings using selected FL algorithms.
  • ...and 12 more figures

Theorems & Definitions (15)

  • Lemma 1
  • proof
  • Lemma 2 (Karimireddy et al., 2020; SCAFFOLD)
  • Lemma 3
  • proof
  • Lemma 4: Perturbation-Induced Gradient Variance
  • proof
  • Lemma 5
  • proof
  • Lemma 6
  • ...and 5 more
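Lemma 4 concerns the perturbation-induced gradient variance, which the abstract writes as $\sigma_\rho^2=\sigma^2+(L\rho)^2$. One way such a term can arise is sketched here as an illustration, not the paper's proof. Assume $f$ is $L$-smooth, the stochastic gradient satisfies $\mathbb{E}[g(y)]=\nabla f(y)$ with $\mathbb{E}\|g(y)-\nabla f(y)\|^2\le\sigma^2$, and the perturbation $\epsilon$ is fixed with $\|\epsilon\|\le\rho$. The cross term then vanishes, giving $\mathbb{E}\|g(x+\epsilon)-\nabla f(x)\|^2=\mathbb{E}\|g(x+\epsilon)-\nabla f(x+\epsilon)\|^2+\|\nabla f(x+\epsilon)-\nabla f(x)\|^2\le\sigma^2+(L\rho)^2$. The Monte Carlo check below uses a hypothetical toy loss $f(x)=\tfrac{L}{2}\|x\|^2$, where the smoothness bound is tight, and Gaussian noise. The function names are illustrative.

```python
import math
import random


def perturbation_variance(sigma2: float, L: float, rho: float) -> float:
    """sigma_rho^2 = sigma^2 + (L * rho)^2, the quantity named in the abstract."""
    return sigma2 + (L * rho) ** 2


def monte_carlo_deviation(sigma2: float, L: float, rho: float,
                          dim: int = 5, n: int = 50_000, seed: int = 0) -> float:
    """Estimate E||g(x + eps) - grad f(x)||^2 for the toy loss f(x) = (L/2)||x||^2.

    For this loss grad f(x + eps) - grad f(x) = L * eps exactly (x cancels), so
    the smoothness bound is tight when ||eps|| = rho. The stochastic gradient
    adds zero-mean Gaussian noise xi with E||xi||^2 = sigma2, independent of eps.
    """
    rng = random.Random(seed)
    eps = [rho] + [0.0] * (dim - 1)      # fixed perturbation with ||eps|| = rho
    std = math.sqrt(sigma2 / dim)        # per-coordinate std so that E||xi||^2 = sigma2
    total = 0.0
    for _ in range(n):
        dev = [L * e + rng.gauss(0.0, std) for e in eps]  # L * eps + xi
        total += sum(d * d for d in dev)
    return total / n
```

For instance, with $\sigma^2=1$, $L=2$, and $\rho=0.5$, `perturbation_variance(1.0, 2.0, 0.5)` returns `2.0`. `monte_carlo_deviation(1.0, 2.0, 0.5)` should land close to that value, up to sampling error. The check also shows that a larger perturbation radius $\rho$ inflates the effective gradient noise quadratically.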