
Bilateral Sharpness-Aware Minimization for Flatter Minima

Jiaxin Deng, Junbiao Pang, Baochang Zhang, Qingming Huang

TL;DR

This work identifies a Flatness Indicator Problem in Sharpness-Aware Minimization (SAM) that arises because SAM focuses only on Max-Sharpness (MaxS). It introduces Min-Sharpness (MinS) and combines the two into Bilateral Sharpness (BilS), which captures flatness in both the gradient-ascent and gradient-descent directions; the resulting optimizer is Bilateral SAM (BSAM). The authors prove convergence and show through extensive experiments on classification, transfer learning, pose estimation, and quantization that BSAM finds flatter minima and improves generalization and robustness relative to SAM and SGD. The results across these tasks suggest that accounting for perturbations on both sides of the current point better guides optimization toward flatter regions of the loss landscape.
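The sharpness measures named above can be written compactly. In this sketch, MaxS is SAM's usual definition and MinS follows the abstract's description; using a separate radius $\rho^{\min}$ for MinS and defining BilS as the plain sum MaxS + MinS are assumptions made for illustration, not necessarily the paper's exact formulation. In that sum the $L(\mathbf{w})$ terms cancel, so BilS becomes the spread between the worst and best losses in the neighborhood. The last line gives the standard first-order (linearized) solutions of the inner maximization and minimization over an $\ell_2$ ball.

```latex
\begin{aligned}
\mathrm{MaxS}(\mathbf{w}) &= \max_{\|\boldsymbol{\varepsilon}\|_2 \le \rho} L(\mathbf{w} + \boldsymbol{\varepsilon}) - L(\mathbf{w}), \\
\mathrm{MinS}(\mathbf{w}) &= L(\mathbf{w}) - \min_{\|\boldsymbol{\varepsilon}\|_2 \le \rho^{\min}} L(\mathbf{w} + \boldsymbol{\varepsilon}), \\
\mathrm{BilS}(\mathbf{w}) &= \mathrm{MaxS}(\mathbf{w}) + \mathrm{MinS}(\mathbf{w})
  = \max_{\|\boldsymbol{\varepsilon}\|_2 \le \rho} L(\mathbf{w} + \boldsymbol{\varepsilon})
  - \min_{\|\boldsymbol{\varepsilon}\|_2 \le \rho^{\min}} L(\mathbf{w} + \boldsymbol{\varepsilon}), \\
\hat{\boldsymbol{\varepsilon}}^{\max} &\approx \rho \, \frac{\nabla L(\mathbf{w})}{\|\nabla L(\mathbf{w})\|_2},
\qquad
\hat{\boldsymbol{\varepsilon}}^{\min} \approx -\rho^{\min} \, \frac{\nabla L(\mathbf{w})}{\|\nabla L(\mathbf{w})\|_2}.
\end{aligned}
```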

Abstract

Sharpness-Aware Minimization (SAM) enhances generalization by reducing a Max-Sharpness (MaxS). Despite this practical success, we empirically found that the MaxS behind SAM's generalization gains faces a "Flatness Indicator Problem" (FIP): SAM considers flatness only in the direction of gradient ascent, so the next minimization region is not sufficiently flat. Because SAM is by nature a greedy search method, a better Flatness Indicator (FI) would bring better generalization of neural networks. In this paper, we propose to use the difference between the training loss and the minimum loss over the neighborhood surrounding the current weight, which we denote Min-Sharpness (MinS). By merging MaxS and MinS, we create a better FI that indicates a flatter direction during optimization. Specifically, we combine this FI with SAM into the proposed Bilateral SAM (BSAM), which finds a flatter minimum than SAM does. Our theoretical analysis proves that BSAM converges to local minima. Extensive experiments demonstrate that BSAM offers better generalization performance and robustness than vanilla SAM across various tasks, namely classification, transfer learning, human pose estimation, and network quantization. Code is publicly available at: https://github.com/ajiaaa/BSAM.
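A minimal sketch of how MaxS, MinS and their combination could be estimated at a single point, assuming SAM-style first-order perturbations and BilS = MaxS + MinS. Both are assumptions for illustration, not the authors' implementation (which is in the linked repository). The quadratic toy loss and the helper names `loss`, `grad` and `sharpness_estimates` are hypothetical.

```python
import math

# Hypothetical toy problem: L(w) = 0.5 * sum_i a_i * w_i**2, with a_i the curvatures.
def loss(w, a):
    return 0.5 * sum(ai * wi * wi for ai, wi in zip(a, w))

def grad(w, a):
    return [ai * wi for ai, wi in zip(a, w)]

def sharpness_estimates(w, a, rho, rho_min):
    """First-order estimates of (MaxS, MinS, BilS) at w.

    Assumes eps_max = rho * g / ||g||, eps_min = -rho_min * g / ||g||
    and BilS = MaxS + MinS; these are illustrative choices only.
    """
    g = grad(w, a)
    g_norm = math.sqrt(sum(gi * gi for gi in g))
    if g_norm == 0.0:
        # The first-order estimate degenerates at a stationary point.
        return 0.0, 0.0, 0.0
    w_max = [wi + rho * gi / g_norm for wi, gi in zip(w, g)]      # ascent side
    w_min = [wi - rho_min * gi / g_norm for wi, gi in zip(w, g)]  # descent side
    max_s = loss(w_max, a) - loss(w, a)  # MaxS: loss increase on the ascent side
    min_s = loss(w, a) - loss(w_min, a)  # MinS: loss decrease on the descent side
    return max_s, min_s, max_s + min_s

# Compare a sharp and a flat quadratic at the same point.
for curv in (10.0, 0.1):
    m_s, n_s, b_s = sharpness_estimates([1.0, 1.0], [curv, curv], rho=0.05, rho_min=0.05)
    print(f"curvature={curv:5.1f}  MaxS={m_s:.4f}  MinS={n_s:.4f}  BilS={b_s:.4f}")
```

On this convex toy loss both measures scale with curvature, so BilS separates the sharp quadratic from the flat one; on non-convex losses the linearized directions can be poor estimates of the true maximum and minimum.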

Paper Structure

This paper contains 24 sections, 1 theorem, 27 equations, 6 figures, 6 tables, 1 algorithm.

Key Result

Theorem 1

Suppose that the true gradient at $\mathbf{w}_t$ is $\mathbf{g}_t = \nabla L(\mathbf{w}_t; D)$ and that the smoothness and bounded-variance assumptions hold. Let $\varsigma = \frac{\|\nabla_{\mathbf{w}} L(\mathbf{w})|_{\mathbf{w} + \hat{\varepsilon}^{\max}(\mathbf{w})}\|}{\|\nabla_{\mathbf{w}} L(\mathbf{w})|_{\mathbf{w} + \hat{\varepsilon}^{\min}(\mathbf{w})}\|}$ … where $Z_1 = \frac{1 - \varsigma}{\tau(1 + \varsigma^2)}$ and $Z_2 = (\tau\sigma^2 + \rho …$
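The two constants that survive in the statement can be computed directly from gradient norms. The sketch below is only arithmetic on those formulas: $\tau$ is treated as a given positive constant because its definition is not part of the statement shown, and the function names `varsigma` and `z1` are hypothetical.

```python
import math

def varsigma(grad_at_max, grad_at_min):
    """Ratio of gradient norms at w + eps_max and at w + eps_min."""
    num = math.sqrt(sum(x * x for x in grad_at_max))
    den = math.sqrt(sum(x * x for x in grad_at_min))
    if den == 0.0:
        raise ValueError("gradient at w + eps_min is zero; varsigma is undefined")
    return num / den

def z1(s, tau):
    """Z_1 = (1 - varsigma) / (tau * (1 + varsigma**2)); tau > 0 is assumed."""
    if tau <= 0.0:
        raise ValueError("tau is assumed to be positive")
    return (1.0 - s) / (tau * (1.0 + s * s))

# Gradient norms 5 and 10 give varsigma = 0.5; with tau = 1, Z_1 = 0.5 / 1.25 = 0.4.
s = varsigma([3.0, 4.0], [0.0, 10.0])
print(s, z1(s, tau=1.0))
```

With $\tau > 0$, the denominator is always positive, so $Z_1$ is positive exactly when $\varsigma < 1$, that is, when the gradient at the ascent-side point is smaller in norm than the gradient at the descent-side point.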

Figures (6)

  • Figure 1: Illustration of the notations of the MaxS, MinS and BilS.
  • Figure 2: The occurrence of gradient conflict under different $\rho^{min}$ and the variation of $\rho^{min}$ with learning rate in BSAM.
  • Figure 3: The cosine similarity between the gradients at the point $\mathbf{w}$ and the point $\mathbf{w} + {\hat{\bm{\varepsilon}} }^{\min }$ during training under different settings of ${\rho}^{min}$.
  • Figure 4: Loss and accuracy curves of SAM and BSAM with WideResNet-28-10.
  • Figure 5: The distribution of top-50 eigenvalues of Hessian on the test set of CIFAR-100 with SGD, SAM and BSAM.
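Figures 2 and 3 track how the gradient at $\mathbf{w} + \hat{\boldsymbol{\varepsilon}}^{\min}$ relates to the gradient at $\mathbf{w}$ as $\rho^{\min}$ varies. The sketch below computes that cosine similarity on a hypothetical two-dimensional quadratic, assuming the first-order perturbation $\hat{\boldsymbol{\varepsilon}}^{\min} = -\rho^{\min}\nabla L(\mathbf{w})/\|\nabla L(\mathbf{w})\|$ and treating a negative cosine as a gradient conflict. Both are assumptions for illustration, not the paper's measurement code, and the helper names are hypothetical.

```python
import math

def grad(w, a):
    # Gradient of the toy quadratic L(w) = 0.5 * sum_i a_i * w_i**2.
    return [ai * wi for ai, wi in zip(a, w)]

def cosine(u, v):
    dot = sum(x * y for x, y in zip(u, v))
    nu = math.sqrt(sum(x * x for x in u))
    nv = math.sqrt(sum(y * y for y in v))
    return dot / (nu * nv)

def cos_at_min_perturbation(w, a, rho_min):
    """Cosine between grad L(w) and grad L(w + eps_min)."""
    g = grad(w, a)
    g_norm = math.sqrt(sum(x * x for x in g))
    # First-order MinS perturbation (assumption): eps_min = -rho_min * g / ||g||.
    w_min = [wi - rho_min * gi / g_norm for wi, gi in zip(w, g)]
    return cosine(g, grad(w_min, a))

w, a = [0.1, 0.1], [1.0, 4.0]
for rho_min in (0.05, 0.1, 0.3):
    c = cos_at_min_perturbation(w, a, rho_min)
    print(f"rho_min={rho_min:.2f}  cos={c:+.3f}  conflict={c < 0}")
```

When $\rho^{\min}$ is large relative to the distance to the minimum, the descent-side point overshoots it and the two gradients point in opposing directions, which is one way such a conflict can arise.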

Theorems & Definitions (3)

  • Theorem 1
  • Proof
  • Proof