
FedNSAM: Consistency of Local and Global Flatness for Federated Learning

Junkang Liu, Fanhua Shang, Yuxuan Tian, Hongying Liu, Yuanyuan Liu

TL;DR

The paper proposes FedNSAM, an algorithm that accelerates SAM in federated learning by introducing global Nesterov momentum into the local update, so that local and global flatness stay consistent.

Abstract

In federated learning (FL), multi-step local updates and data heterogeneity usually lead to sharper global minima, which degrades the performance of the global model. Popular FL algorithms integrate sharpness-aware minimization (SAM) into local training to address this issue. However, under high data heterogeneity, flatness in local training does not imply flatness of the global model. Therefore, minimizing the sharpness of the local loss surfaces on client data does not by itself allow SAM to improve the generalization of the global model in FL. We define the flatness distance to explain this phenomenon. By rethinking SAM in FL and theoretically analyzing the flatness distance, we propose FedNSAM, a novel algorithm that accelerates SAM by introducing global Nesterov momentum into the local update to harmonize local and global flatness. FedNSAM uses the global Nesterov momentum both as the direction of each client's local estimate of the global perturbation and for extrapolation. Theoretically, we prove that Nesterov extrapolation yields a tighter convergence bound than that of FedSAM. Empirically, we conduct comprehensive experiments on CNN and Transformer models to verify the superior performance and efficiency of FedNSAM. The code is available at https://github.com/junkangLiu0/FedNSAM.
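The abstract describes FedNSAM only at a high level. The sketch below contrasts a standard FedSAM local step, in which the SAM perturbation follows the client's own gradient, with a hypothetical momentum-guided step in which the server's global momentum `m` supplies both the perturbation direction and a Nesterov look-ahead. Everything about the second step (the function names, the sign convention, where the look-ahead is applied, and the server rule in `server_update`) is an assumption made for illustration, not the authors' implementation; the official repository listed in the abstract holds the real algorithm.

```python
import numpy as np


def fedsam_local_step(x, grad_fn, lr, rho):
    """Standard SAM step (as in FedSAM): the ascent perturbation uses the LOCAL gradient."""
    g = grad_fn(x)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)   # local ascent direction, radius rho
    return x - lr * grad_fn(x + eps)              # descend with the gradient at the perturbed point


def momentum_sam_local_step(x, grad_fn, lr, rho, lam, m):
    """Hypothetical FedNSAM-style step (an illustration, not the authors' code).

    m   : global momentum sent by the server; it points roughly downhill.
    lam : momentum coefficient (the lambda in Theorem 1).
    Assumption: -m / ||m|| stands in for the global ascent direction, and the
    gradient is taken at a Nesterov look-ahead point x + lam * m.
    """
    eps = -rho * m / (np.linalg.norm(m) + 1e-12)  # shared, global perturbation estimate (zero if m == 0)
    lookahead = x + lam * m                        # Nesterov extrapolation
    return x - lr * grad_fn(lookahead + eps)


def server_update(theta, m, client_models, lam):
    """Hypothetical server rule: average the clients, then fold the update into momentum."""
    delta = np.mean(client_models, axis=0) - theta
    m_new = lam * m + delta
    return theta + m_new, m_new


def simulate(centers, rounds=200, local_steps=5, lr=0.1, rho=0.05, lam=0.0):
    """Toy run with quadratic client losses F_i(x) = 0.5 * ||x - c_i||^2."""
    theta = np.zeros_like(centers[0], dtype=float)
    m = np.zeros_like(theta)
    for _ in range(rounds):
        models = []
        for c in centers:
            grad_fn = lambda x, c=c: x - c        # gradient of the quadratic loss
            x = theta.copy()
            for _ in range(local_steps):
                x = momentum_sam_local_step(x, grad_fn, lr, rho, lam, m)
            models.append(x)
        theta, m = server_update(theta, m, models, lam)
    return theta
```

With `lam=0.0` the server simply averages the client models, and on these quadratic losses the global model settles within distance `rho` of the average of the client optima, because the shared perturbation biases each round by at most `rho`.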

Paper Structure (25 sections, 15 theorems, 34 equations, 7 figures, 6 tables, 2 algorithms)

Key Result

Theorem 1

Suppose that the local objectives $\left\{F_i\right\}_{i=1}^N$ are non-convex and $L$-smooth. With $\eta \leq \frac{(1-\lambda)^2}{128 K L}$ and $\rho=\frac{1}{\sqrt{T}}$, FedNSAM satisfies a convergence bound involving the quantities $M_1^2:=\sigma^2+K\left(1-\frac{S}{N}\right) \sigma_g^2$, $M_2:=\frac{\sigma^2}{K}+\sigma_g^2$, and $F:=F(\boldsymbol{\theta}^0)-F(\boldsymbol{\theta}^{\star})$, where $\left|S_t\right|=S$ of the $N$ clients participate in each round.
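The step-size cap, perturbation radius, and constants in Theorem 1 are straightforward to evaluate. The sketch below plugs in illustrative inputs; the function name and every numeric value are arbitrary choices, not settings reported in the paper, and reading $\sigma$ and $\sigma_g$ as stochastic-gradient noise and client heterogeneity levels is an assumption.

```python
def theorem1_settings(L, K, lam, T, sigma, sigma_g, S, N):
    """Evaluate the step-size cap, perturbation radius, and constants of Theorem 1.

    L: smoothness constant, K: local steps, lam: momentum coefficient lambda,
    T: communication rounds, sigma / sigma_g: the noise and heterogeneity levels
    as they appear in M_1 and M_2 (interpretation assumed), S of N clients per round.
    """
    eta_max = (1 - lam) ** 2 / (128 * K * L)             # eta <= (1 - lambda)^2 / (128 K L)
    rho = (1 / T) ** 0.5                                 # rho = sqrt(1 / T)
    M1_sq = sigma ** 2 + K * (1 - S / N) * sigma_g ** 2  # M_1^2
    M2 = sigma ** 2 / K + sigma_g ** 2                   # M_2
    return eta_max, rho, M1_sq, M2


# Illustrative inputs only (not values from the paper).
eta_max, rho, M1_sq, M2 = theorem1_settings(L=1.0, K=5, lam=0.9, T=10_000,
                                            sigma=1.0, sigma_g=1.0, S=10, N=100)
print(eta_max, rho, M1_sq, M2)
```

Two things follow directly from the formulas: a larger $\lambda$ shrinks the step-size cap quadratically ($\lambda = 0.9$ gives a cap 100 times smaller than $\lambda = 0$), and the $\sigma_g^2$ term in $M_1^2$ vanishes under full participation ($S = N$).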

Figures (7)

  • Figure 1: (a) and (b) show the global training loss surface of FedSAM [qu2022generalized] under Dirichlet distributions with coefficients 0.6 (low data heterogeneity) and 0.1 (high data heterogeneity) on CIFAR-100 with ResNet-18. (c) shows the training loss surface of FedNSAM under Dirichlet-0.1. FedSAM finds globally flat minima under low data heterogeneity but fails under high data heterogeneity. (d) illustrates that under low data heterogeneity the clients' flat regions lie close together, so the global model falls within them. (e) illustrates that under high data heterogeneity the clients' flat regions lie far apart, so the global model cannot fall within them. (f) illustrates that FedNSAM draws the clients' flat regions closer together through alignment correction, so the global model falls within the flat region of each client; accordingly, FedNSAM's global model finds flat minima in (c).
  • Figure 2: Flatness distance (left) and global sharpness (right) during federated training on CIFAR-100, using Dirichlet distributions with coefficients 0.1, 0.3, and 0.6, 100 clients, and a 10% participation rate. Test accuracies are 40.18%, 46.02%, and 47.83% for FedSAM with Dirichlet-0.1, 0.3, and 0.6, respectively, and 58.53% for FedNSAM with Dirichlet-0.1.
  • Figure 3: (a) and (b) depict the local update procedures of FedSAM and FedNSAM, respectively. FedNSAM incorporates global Nesterov momentum into local sharpness-aware updates to improve alignment between local and global flatness.
  • Figure 4: Convergence plots for FedNSAM and other baselines under Dirichlet-0.6 partitions of CIFAR-10 and CIFAR-100 with ResNet-18.
  • Figure 5: Convergence plots for FedNSAM with different values of $\lambda$ and $\rho$ on CIFAR-100 with ResNet-18.
  • ...and 2 more figures
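Figures 1 and 2 rely on Dirichlet-0.1, 0.3, and 0.6 label splits across 100 clients. The exact partitioning code used for the paper is not given here; the following is a minimal sketch of the common per-class Dirichlet label split, which is an assumed scheme, with `dirichlet_partition` a hypothetical helper name.

```python
import numpy as np


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with per-class Dirichlet(alpha) proportions.

    Smaller alpha gives more skewed label mixes per client (higher heterogeneity).
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == cls))    # shuffled indices of this class
        props = rng.dirichlet(alpha * np.ones(num_clients))     # share of this class per client
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)   # split points into idx
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices


# Toy labels: 10 classes with 100 samples each, split over 100 clients.
labels = np.repeat(np.arange(10), 100)
parts = dirichlet_partition(labels, num_clients=100, alpha=0.1)
```

Smaller `alpha` concentrates each class on fewer clients, which corresponds to the high-heterogeneity regime (Dirichlet-0.1) in which the figures show FedSAM losing global flatness.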

Theorems & Definitions (15)

  • Theorem 1: Convergence for non-convex functions
  • Theorem 2
  • Theorem 3
  • Lemma 1
  • Lemma 2
  • Lemma 3
  • Lemma 4
  • Theorem 4: Convergence for non-convex functions
  • Lemma 5
  • Lemma 6
  • ...and 5 more