Friendly Sharpness-Aware Minimization

Tao Li, Pan Zhou, Zhengbao He, Xinwen Cheng, Xiaolin Huang

TL;DR

F-SAM aims to mitigate the negative effects of the full gradient component in SAM's adversarial perturbation. It removes the full gradient, estimated by an exponential moving average of historical stochastic gradients, and then leverages the remaining stochastic gradient noise for improved generalization.

Abstract

Sharpness-Aware Minimization (SAM) has been instrumental in improving deep neural network training by minimizing both the training loss and the loss sharpness. Despite its practical success, the mechanisms behind SAM's generalization improvements remain elusive, limiting its progress in deep learning optimization. In this work, we investigate SAM's core components for generalization improvement and introduce "Friendly-SAM" (F-SAM) to further enhance SAM's generalization. Our investigation reveals the key role of the batch-specific stochastic gradient noise within the adversarial perturbation (i.e., the current minibatch gradient), which significantly influences SAM's generalization performance. By decomposing the adversarial perturbation in SAM into a full gradient component and a stochastic gradient noise component, we discover that relying solely on the full gradient component degrades generalization, while excluding it improves performance. A possible reason is that the full gradient component increases the sharpness loss over the entire dataset, which is inconsistent with the subsequent sharpness minimization step performed only on the current minibatch. Inspired by these insights, F-SAM aims to mitigate the negative effects of the full gradient component: it removes the full gradient, estimated by an exponential moving average (EMA) of historical stochastic gradients, and then leverages the remaining stochastic gradient noise for improved generalization. Moreover, we provide theoretical validation for the EMA approximation and prove the convergence of F-SAM on non-convex problems. Extensive experiments demonstrate the superior generalization performance and robustness of F-SAM over vanilla SAM. Code is available at https://github.com/nblt/F-SAM.
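The F-SAM perturbation described above can be sketched in a few lines. This is a minimal NumPy sketch, not the authors' implementation (that lives in the linked repository), and it rests on several assumptions: the EMA takes the standard form m ← λm + (1 − λ)g and is updated before the subtraction, the whole EMA estimate is subtracted with no extra scaling, the perturbation is rescaled to radius ρ as in vanilla SAM, and SGD is the base optimizer. The names `fsam_step` and `grad_fn`, the toy least-squares problem, and all hyperparameter values are hypothetical.

```python
import numpy as np


def fsam_step(w, m, grad_fn, batch, lr=0.1, rho=0.05, lam=0.9, eps=1e-12):
    """One F-SAM-style update (illustrative sketch, not the reference code).

    w       -- current parameters
    m       -- EMA of past minibatch gradients (estimate of the full gradient)
    grad_fn -- hypothetical helper: grad_fn(w, batch) -> minibatch gradient
    """
    g = grad_fn(w, batch)                      # minibatch gradient at w
    m = lam * m + (1.0 - lam) * g              # assumed EMA form, updated first
    d = g - m                                  # drop estimated full-gradient part
    e = rho * d / (np.linalg.norm(d) + eps)    # rescale to radius rho, as in SAM
    g_adv = grad_fn(w + e, batch)              # gradient at the perturbed point
    return w - lr * g_adv, m                   # plain SGD as the base optimizer


# Toy usage: least squares, L_B(w) = ||X_B w - y_B||^2 / (2 |B|).
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 5))
w_true = rng.normal(size=5)
y = X @ w_true + 0.1 * rng.normal(size=256)


def grad_fn(w, idx):
    """Minibatch gradient of the toy least-squares loss."""
    Xb, yb = X[idx], y[idx]
    return Xb.T @ (Xb @ w - yb) / len(idx)


w, m = np.zeros(5), np.zeros(5)
for _ in range(500):
    idx = rng.choice(len(y), size=32, replace=False)   # fresh minibatch
    w, m = fsam_step(w, m, grad_fn, idx)
```

Replacing `d = g - m` with `d = g` turns the same function into a vanilla SAM step, which makes the two perturbation rules easy to compare side by side under these assumptions.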

Paper Structure (31 sections, 5 theorems, 44 equations, 7 figures, 6 tables, 1 algorithm)

Key Result

Theorem 1

Suppose Assumptions 1, 2, and 3 hold. Assume that SAM uses SGD with learning rate $\gamma$ as the base optimizer to update the model parameters in Eqn. eqn:gradient. Then, by setting $\lambda = 1 - C\gamma^{2/3}$, after $T > C'\gamma^{-2/3}$ training iterations, …, where $C$ and $C'$ are two universal constants.
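To get a feel for the scaling in Theorem 1, here is a worked example with purely illustrative values $C = 1$ and $\gamma = 10^{-3}$ (the theorem only says $C$ and $C'$ are universal constants, so these numbers are hypothetical). It also assumes that $\lambda$ is the EMA coefficient in the standard form $m_t = \lambda m_{t-1} + (1-\lambda)\nabla_\mathcal{B} L(\boldsymbol{w}_t)$:

```latex
\gamma^{2/3} = (10^{-3})^{2/3} = 10^{-2}, \qquad
\lambda = 1 - C\gamma^{2/3} = 1 - 10^{-2} = 0.99, \qquad
T > C'\gamma^{-2/3} = 100\,C', \qquad
\frac{1}{1-\lambda} = \frac{\gamma^{-2/3}}{C} = 100.
```

Under these assumptions the EMA effectively averages over roughly the last $1/(1-\lambda) = 100$ minibatch gradients, and both this averaging window and the iteration threshold on $T$ grow as $\gamma^{-2/3}$ when the learning rate shrinks.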

Figures (7)

  • Figure 1: Investigation of SAM's adversarial perturbation direction. We decompose the minibatch gradient $\nabla_\mathcal{B} L(\boldsymbol{w})$ into two components: the full gradient component and the remaining batch-specific stochastic gradient noise. Using only the full gradient component leads to a dramatic degradation in generalization, while using only the noise component improves generalization.
  • Figure 2: Performance comparison of different versions of SAM with SGD/AdamW as the base optimizer. In (a), SAM-full denotes SAM using the full gradient component as the perturbation. In (b), SAM-db denotes SAM using different randomly selected data batches for the perturbation and for the subsequent minimization step. (c) compares SAM with different minibatch sizes for computing the perturbation, while the minibatch size for the subsequent minimization step is fixed at 128.
  • Figure 3: Results under different perturbation radii $\rho$.
  • Figure 4: Performance comparison with different batch sizes.
  • Figure A5: Results on enlarging the batch size of SAM's adversarial perturbation.
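The gradient decomposition in Figure 1 and the perturbation variants compared in Figure 2 can be illustrated on a toy problem. This is a hedged NumPy sketch, assuming the simple additive split $\nabla_\mathcal{B} L(\boldsymbol{w}) = \nabla L(\boldsymbol{w}) + (\nabla_\mathcal{B} L(\boldsymbol{w}) - \nabla L(\boldsymbol{w}))$ and that every variant rescales its direction to radius $\rho$ as vanilla SAM does. The least-squares data, the helper names `perturb` and `grad`, and the batch size of 512 are illustrative and not taken from the paper.

```python
import numpy as np


def perturb(direction, rho=0.05, eps=1e-12):
    """Scale a direction to the SAM perturbation radius rho."""
    return rho * direction / (np.linalg.norm(direction) + eps)


# Hypothetical least-squares problem standing in for a training set.
rng = np.random.default_rng(0)
n, dim = 1024, 10
X = rng.normal(size=(n, dim))
y = X @ rng.normal(size=dim) + 0.5 * rng.normal(size=n)
w = np.zeros(dim)


def grad(w, idx):
    """Gradient of ||X_idx w - y_idx||^2 / (2 |idx|)."""
    Xb, yb = X[idx], y[idx]
    return Xb.T @ (Xb @ w - yb) / len(idx)


g_full = grad(w, np.arange(n))                   # full gradient
batch = rng.choice(n, size=128, replace=False)   # current minibatch
g_batch = grad(w, batch)                         # minibatch gradient
noise = g_batch - g_full                         # batch-specific noise

eps_sam = perturb(g_batch)     # vanilla SAM: minibatch gradient
eps_full = perturb(g_full)     # "SAM-full": full gradient component only
eps_noise = perturb(noise)     # noise component only

# "SAM-db": perturbation from a different random batch than the minimization step.
other = rng.choice(n, size=128, replace=False)
eps_db = perturb(grad(w, other))

# Figure 2(c)-style: larger batch for the perturbation; minimization would use `batch`.
big = rng.choice(n, size=512, replace=False)
eps_big = perturb(grad(w, big))
```

Only the perturbation direction differs between these variants; in each case the subsequent minimization step would evaluate the gradient on `batch` at `w + eps_*`.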

Theorems & Definitions (9)

  • Theorem 1
  • Lemma 1
  • Theorem 2
  • Theorem A3: Vector Bernstein
  • Corollary A1
  • proof
  • proof
  • proof
  • proof