
Revisiting Sharpness-Aware Minimization: A More Faithful and Effective Implementation

Jianlong Chen, Zhiming Zhou

TL;DR

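eXplicit Sharpness-Aware Minimization (XSAM) addresses two weaknesses of SAM's practical implementation. It tackles the inaccuracy of the gradient at the single-step ascent point by explicitly estimating the direction of the maximum during training. It tackles the degradation that comes with more ascent steps by crafting a search space that effectively leverages the gradient information at the multi-step ascent point.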

Abstract

Sharpness-Aware Minimization (SAM) enhances generalization by minimizing the maximum training loss within a predefined neighborhood around the parameters. However, its practical implementation approximates this as gradient ascent(s) followed by applying the gradient at the ascent point to update the current parameters. This practice can be justified as approximately optimizing the objective by neglecting the (full) derivative of the ascent point with respect to the current parameters. Nevertheless, a direct and intuitive understanding of why using the gradient at the ascent point to update the current parameters works better is still lacking. Our work bridges this gap by proposing a novel and intuitive interpretation. We show that the gradient at the single-step ascent point, when applied to the current parameters, provides a better approximation than the local gradient of the direction from the current parameters toward the maximum within the local neighborhood. This improved approximation thereby enables a more direct escape from the maximum within the local neighborhood. However, our analysis further reveals two issues. First, the approximation by the gradient at the single-step ascent point is often inaccurate. Second, the approximation quality may degrade as the number of ascent steps increases. To address these limitations, we propose eXplicit Sharpness-Aware Minimization (XSAM). It tackles the first issue by explicitly estimating the direction of the maximum during training, and the second by crafting a search space that effectively leverages the gradient information at the multi-step ascent point. XSAM features a unified formulation that applies to both single-step and multi-step settings and incurs only negligible computational overhead. Extensive experiments demonstrate that XSAM consistently outperforms existing counterparts.
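The practical SAM update described in the abstract (gradient ascent step(s) first, then applying the gradient at the ascent point to the current parameters) can be sketched as follows. This is a minimal illustration, not the authors' code. The 2D tuple parameters, the toy quadratic `toy_loss`, and the choice of splitting the radius ρ into k equal normalized ascent steps of length ρ/k are all assumptions made here. Other multi-step variants may use different step sizes or a projection onto the ρ-ball.

```python
import math
from typing import Callable, Tuple

Vec = Tuple[float, float]
GradFn = Callable[[Vec], Vec]


def _norm(v: Vec) -> float:
    return math.hypot(v[0], v[1])


def sam_step(theta0: Vec, grad: GradFn, rho: float, lr: float, k: int = 1) -> Vec:
    """One SAM update on a 2D parameter vector (illustrative sketch).

    Takes k normalized gradient-ascent steps from theta0, each of length rho / k
    (an assumption, so the ascent point stays within radius rho). It then updates
    theta0, not the ascent point, with the gradient evaluated at the ascent point.
    """
    theta = theta0
    for _ in range(k):
        g = grad(theta)
        n = _norm(g)
        if n == 0.0:  # stationary point: no ascent direction available
            break
        step = rho / k
        theta = (theta[0] + step * g[0] / n, theta[1] + step * g[1] / n)
    g_ascent = grad(theta)  # gradient at the (single- or multi-step) ascent point
    return (theta0[0] - lr * g_ascent[0], theta0[1] - lr * g_ascent[1])


# Hypothetical toy objective: L(x, y) = 0.5 * (x^2 + 10 * y^2).
def toy_loss(t: Vec) -> float:
    return 0.5 * (t[0] ** 2 + 10.0 * t[1] ** 2)


def toy_grad(t: Vec) -> Vec:
    return (t[0], 10.0 * t[1])


def run(k: int, steps: int = 50, rho: float = 0.05, lr: float = 0.05) -> Vec:
    """Apply `steps` SAM updates with k ascent steps each, starting at (1, 1)."""
    theta: Vec = (1.0, 1.0)
    for _ in range(steps):
        theta = sam_step(theta, toy_grad, rho=rho, lr=lr, k=k)
    return theta


for k in (1, 3):
    print(f"k={k}: loss after 50 SAM steps = {toy_loss(run(k)):.4f}")
```

With `rho=0` the ascent step vanishes, and `sam_step` reduces to plain gradient descent. This makes a convenient sanity check.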

Paper Structure

This paper contains 32 sections, 2 theorems, 31 equations, 18 figures, 16 tables, and 1 algorithm.

Key Result

Proposition 1

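Let $L: \mathbb{R}^n \to \mathbb{R}$ be a twice continuously differentiable function that admits a second-order approximation at $\vartheta_0$ with: … Then there exists $\rho_0 > 0$ such that for all $\rho_m > \rho_0$:

$$L\left(\vartheta_0 + \rho_m \frac{g_1}{\|g_1\|}\right) > L\left(\vartheta_0 + \rho_m \frac{g_0}{\|g_0\|}\right),$$

where $g_0$ is the gradient at $\vartheta_0$ and $g_1$ is the gradient at the single-step ascent point $\vartheta_1$.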

Figures (18)

  • Figure 1: (a) Visualization of the local loss surface of single-step SAM on the hyperplane spanned by the gradient $g_0$ at the current parameters $\vartheta_0$ and the gradient $g_1$ at the single-step ascent point $\vartheta_1$. $\vartheta_0$ is set as the origin, the $Y$-axis is defined along the direction of $g_0$, and the $X$-axis is aligned with the component of $g_1$ perpendicular to $g_0$. The gradient arrows are drawn with length $\rho$. $\boldsymbol{g_1\!@\vartheta_0}$ (i.e., $\boldsymbol{g_1}$ applied to $\boldsymbol{\vartheta_0}$) points clearly closer than $\boldsymbol{g_0}$ to the direction from $\boldsymbol{\vartheta_0}$ toward the maximum within the local neighborhood. This target direction runs roughly from the origin to the upper-right corner of the figure. For sufficiently large $\rho_m$, the loss along $g_1\!@\vartheta_0$ (i.e., $L(\vartheta_0 + \rho_m \cdot {g_1}/{\|g_1\|})$) is higher than that along $g_0$ (i.e., $L(\vartheta_0 + \rho_m \cdot {g_0}/{\|g_0\|})$). (b) A simulation of multi-step SAM on a 2D test function. The approximation quality of the SAM gradient may degrade as the number of ascent steps increases: $g_2\!@\vartheta_0$ identifies the direction from $\vartheta_0$ toward the maximum within the neighborhood (the upper-left high-loss region in yellow) less accurately than $g_1\!@\vartheta_0$.
  • Figure 1: XSAM
  • Figure 2: Slow variation of $\alpha^*$ during training.
  • Figure 3: (a) Training trajectory comparisons on 2D test function. (b)-(c) Test accuracy comparisons of ResNet-18 trained on CIFAR-100 in single-step and multi-step ($k=3$) settings with varying $\rho$.
  • Figure 4: XSAM robustness to the $\alpha^*$ update frequency.
  • ...and 13 more figures
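Figure 1(a) and Proposition 1 compare the loss reached along $g_1\!@\vartheta_0$ with the loss reached along $g_0$. This claim can be checked numerically on a toy problem. The sketch below is our own illustration, not the paper's experiment. It assumes a hypothetical convex quadratic $L(\vartheta) = \tfrac12 \vartheta^\top H \vartheta$ with $H = \mathrm{diag}(1, 10)$, $\vartheta_0 = (1, 0.1)$, and a single normalized ascent step of radius $\rho = 0.1$. It finds the direction of the maximum by brute-force search over unit directions. For a convex function, the maximum over the ball lies on its boundary. This exhaustive search is only feasible in 2D and is not meant to represent XSAM's estimator.

```python
import math

# Hypothetical convex quadratic L(theta) = 0.5 * theta^T H theta, H = diag(1, 10).
H = (1.0, 10.0)


def loss(t):
    return 0.5 * (H[0] * t[0] ** 2 + H[1] * t[1] ** 2)


def grad(t):
    return (H[0] * t[0], H[1] * t[1])


def unit(v):
    n = math.hypot(v[0], v[1])
    return (v[0] / n, v[1] / n)


def loss_along(theta0, direction, rho_m):
    """L(theta0 + rho_m * direction / ||direction||)."""
    u = unit(direction)
    return loss((theta0[0] + rho_m * u[0], theta0[1] + rho_m * u[1]))


def max_direction(theta0, rho_m, n_angles=3600):
    """Brute-force the unit u maximizing L(theta0 + rho_m * u).

    For convex L, the maximum over the rho_m-ball is attained on its boundary,
    so this gives the direction from theta0 toward the maximum in the ball.
    """
    best_u, best_val = None, -math.inf
    for i in range(n_angles):
        a = 2.0 * math.pi * i / n_angles
        u = (math.cos(a), math.sin(a))
        val = loss_along(theta0, u, rho_m)
        if val > best_val:
            best_u, best_val = u, val
    return best_u


def cosine(a, b):
    ua, ub = unit(a), unit(b)
    return ua[0] * ub[0] + ua[1] * ub[1]


theta0 = (1.0, 0.1)
rho = 0.1                      # single-step ascent radius
g0 = grad(theta0)              # local gradient
u0 = unit(g0)
theta1 = (theta0[0] + rho * u0[0], theta0[1] + rho * u0[1])
g1 = grad(theta1)              # gradient at the ascent point, used at theta0

rho_m = 0.5
u_star = max_direction(theta0, rho_m)
print("cos(g0, u*) =", round(cosine(g0, u_star), 3))
print("cos(g1, u*) =", round(cosine(g1, u_star), 3))
for r in (0.01, 0.5):
    print(f"rho_m={r}: along g0 {loss_along(theta0, g0, r):.4f}, "
          f"along g1 {loss_along(theta0, g1, r):.4f}")
```

With these settings, $g_1$ has the larger cosine similarity to the brute-forced maximum direction. The loss along $g_1$ exceeds the loss along $g_0$ at $\rho_m = 0.5$ but not at $\rho_m = 0.01$, which mirrors the "sufficiently large $\rho_m$" condition.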

Theorems & Definitions (8)

  • Proposition 1
  • Proposition 1
  • proof
  • Remark
  • Remark
  • proof
  • Remark
  • Remark