
Dynamic Regularized Sharpness Aware Minimization in Federated Learning: Approaching Global Consistency and Smooth Landscape

Yan Sun, Li Shen, Shixiang Chen, Liang Ding, Dacheng Tao

TL;DR

This work proposes a novel and general algorithm that uses a dynamic regularizer to steer the local optima toward the global objective, and revises that objective with a global Sharpness Aware Minimization (SAM) optimizer so that clients search for consistent flat minima.

Abstract

In federated learning (FL), a cluster of local clients is coordinated by a global server and cooperatively trains one model with privacy protection. Because of the multiple local updates and the isolated non-iid datasets, clients are prone to overfitting to their own optima, which deviate sharply from the global objective and significantly undermine performance. Most previous works focus only on improving the consistency between the local and global objectives, alleviating this harmful client drift from an optimization perspective; their performance deteriorates markedly under high heterogeneity. In this work, we propose a novel and general algorithm, FedSMOO, which jointly considers the optimization and generalization targets to improve performance in FL efficiently. Concretely, FedSMOO adopts a dynamic regularizer to steer the local optima toward the global objective, which is in turn revised by the global Sharpness Aware Minimization (SAM) optimizer to search for consistent flat minima. Our theoretical analysis indicates that FedSMOO achieves a fast $\mathcal{O}(1/T)$ convergence rate with a low generalization bound. Extensive numerical studies on real-world datasets verify its efficiency and generality.
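To make the two ingredients named in the abstract concrete, here is a minimal sketch of one client's local training that combines a SAM gradient with a dynamic regularizer. It rests on several assumptions that this summary does not confirm: the regularized local objective is taken in a FedDyn-style form $f_i(w) - \langle \lambda_i, w\rangle + \frac{1}{2\beta}\Vert w - w^t\Vert^2$, the perturbation is plain local SAM (the global correction of the perturbation that FedSMOO describes is not modeled), and `toy_grad`, `sam_gradient`, `local_step`, and `local_train` are hypothetical names.

```python
import math

def toy_grad(w, center):
    """Gradient of a hypothetical toy local loss 0.5 * ||w - center||^2."""
    return [wi - ci for wi, ci in zip(w, center)]

def sam_gradient(w, grad_fn, r):
    """Gradient evaluated at the perturbed point w + r * g / ||g|| (plain SAM)."""
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0 or r == 0.0:
        return g  # no perturbation possible or requested
    w_adv = [wi + r * gi / norm for wi, gi in zip(w, g)]
    return grad_fn(w_adv)

def local_step(w, w_global, dual, grad_fn, lr, r, beta):
    """One descent step on the assumed regularized objective.

    Direction: SAM gradient - dual + (w - w_global) / beta.
    """
    g = sam_gradient(w, grad_fn, r)
    return [wi - lr * (gi - di + (wi - wgi) / beta)
            for wi, gi, di, wgi in zip(w, g, dual, w_global)]

def local_train(w_global, dual, grad_fn, lr, r, beta, steps):
    """Run `steps` local updates, then refresh the dual variable (assumed rule)."""
    w = list(w_global)
    for _ in range(steps):
        w = local_step(w, w_global, dual, grad_fn, lr, r, beta)
    new_dual = [di - (wi - wgi) / beta
                for di, wi, wgi in zip(dual, w, w_global)]
    return w, new_dual
```

With `r = 0` the step reduces to gradient descent on the regularized objective, which makes the pull toward `w_global` easy to check on the toy quadratic. A positive `r` evaluates the gradient at the perturbed point instead.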

Paper Structure

This paper contains 34 sections, 10 theorems, 51 equations, 10 figures, 8 tables, and 1 algorithm.

Key Result

Theorem 4.1

Let the assumptions hold, let the size of the active clients' set be $\vert \left[\mathcal{n}\right]\vert=n$ and, similarly, $\vert \left[\mathcal{m}\right]\vert=m$, let $r\leq \frac{4\kappa_r}{\sqrt{nT}}$ where $\kappa_r\in\mathbb{R}$ is a constant, and let $\beta\leq \frac{\sqrt{n}}{6\sqrt{6m}L}$, then … where $\zeta\in\left(0,\frac{1}{2}\right)$ is a constant, $\kappa_f\triangleq f(\overline{w}^1)-f^\…
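The two hyperparameter conditions in the theorem can be evaluated numerically. The sketch below computes only the upper bounds on $r$ and $\beta$ stated above. The helper names are hypothetical, and $L$ is assumed to be the constant from the paper's assumptions (presumably a smoothness constant), which this excerpt does not define.

```python
import math

def max_sam_radius(kappa_r, n, T):
    """Upper bound on r from the theorem: 4 * kappa_r / sqrt(n * T)."""
    return 4.0 * kappa_r / math.sqrt(n * T)

def max_beta(n, m, L):
    """Upper bound on beta from the theorem: sqrt(n) / (6 * sqrt(6 * m) * L)."""
    return math.sqrt(n) / (6.0 * math.sqrt(6.0 * m) * L)

# Illustrative values only; kappa_r and L are unknown constants in practice.
r_bound = max_sam_radius(kappa_r=1.0, n=10, T=1000)
beta_bound = max_beta(n=24, m=100, L=1.0)
```

Both bounds depend on the number of active clients $n$: the bound on $r$ shrinks as $n$ or $T$ grows, while the bound on $\beta$ scales with $\sqrt{n/m}$.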

Figures (10)

  • Figure 1: A toy schematic illustrating a failure case of FedSAM. We assume $m=2$ and $f=(f_1+f_2)/2$. The dotted lines represent the general loss surface trained by SGD, and the solid lines represent the flat loss surface trained by SAM. Red, green, and blue correspond to the local $f_1$, the local $f_2$, and the global $f$, respectively.
  • Figure 2: Distribution across categories on CIFAR-10 when sampling with/without replacement, under Dirichlet coefficient $u=0.1$ with $m=100$ total clients. The standard deviation of the number of samples approaches 829.14, which greatly increases the imbalance and better approximates practical scenarios.
  • Figure 3: Visualization of the loss landscape of a ResNet-18 backbone trained via FedAvg, SCAFFOLD, MoFedSAM and FedSMOO on the CIFAR-10 dataset. For clarity, we draw FedSMOO as a grid surface and compare it with each of the other three baselines separately. FedSMOO approaches a more general and flatter loss landscape, which improves generalization performance in FL.
  • Figure 4: Hyperparameter sensitivity studies of local intervals, learning rate decay, penalty coefficient $\beta$ and SAM-lr $r$ on CIFAR-10.
  • Figure 5: Heat-map of the Dirichlet split and Pathological split.
  • ...and 5 more figures
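The Dirichlet split mentioned in Figures 2 and 5 can be illustrated with a short sketch. It assumes one common construction: for each class, client proportions are drawn from a symmetric Dirichlet with coefficient $u$, and that class's sample indices are divided accordingly, without replacement. The paper's exact procedure, including its with-replacement variant, is not given in this summary, and the function names are hypothetical.

```python
import random

def dirichlet(alpha, k, rng):
    """Draw one sample from a symmetric Dirichlet(alpha) of dimension k."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    if total == 0.0:  # extremely unlikely underflow; fall back to uniform
        return [1.0 / k] * k
    return [d / total for d in draws]

def dirichlet_split(labels, num_clients, u, seed=0):
    """Partition sample indices across clients, class by class (assumed scheme)."""
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    clients = [[] for _ in range(num_clients)]
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        props = dirichlet(u, num_clients, rng)
        start, acc = 0, 0.0
        for c, p in enumerate(props):
            acc += p
            if c == num_clients - 1:
                end = len(idxs)  # last client takes any remainder
            else:
                end = min(len(idxs), round(acc * len(idxs)))
            end = max(end, start)  # keep cut points non-decreasing
            clients[c].extend(idxs[start:end])
            start = end
    return clients

# Toy labels standing in for a 10-class dataset such as CIFAR-10.
toy_labels = [i % 10 for i in range(1000)]
parts = dirichlet_split(toy_labels, num_clients=100, u=0.1)
```

A small coefficient such as $u=0.1$ concentrates each class on a few clients, which is what produces the strongly imbalanced per-client distributions the caption describes.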

Theorems & Definitions (18)

  • Theorem 4.1
  • Remark 4.2
  • Remark 4.3
  • Theorem 4.4
  • Lemma 3.1
  • proof
  • Lemma 3.2
  • proof
  • Lemma 3.3
  • proof
  • ...and 8 more