Sharpness-Aware Minimization Efficiently Selects Flatter Minima Late in Training

Zhanpeng Zhou, Mingze Wang, Yuchen Mao, Bingrui Li, Junchi Yan

TL;DR

This work investigates why Sharpness-Aware Minimization (SAM) improves generalization and reveals that SAM can efficiently select flatter minima when applied late in training. It introduces a switching scheme between SGD and SAM and uncovers a two-phase training dynamic: an initial fast escape from the sharp SGD minimum, followed by rapid convergence to a flatter minimum within the same valley. The authors provide both linear-stability and non-local analyses (P1–P4) to justify SAM's bias toward flatter minima and demonstrate exponential convergence under standard smoothness and PL conditions. Empirically, late-phase SAM matches full SAM in test performance and sharpness across datasets, and the findings extend to Adversarial Training (AT). The results highlight the importance of late-stage optimization in shaping generalization and robustness, with practical implications for reducing compute via short late-phase SAM runs.

Abstract

Sharpness-Aware Minimization (SAM) has substantially improved the generalization of neural networks under various settings. Despite the success, its effectiveness remains poorly understood. In this work, we discover an intriguing phenomenon in the training dynamics of SAM, shedding light on understanding its implicit bias towards flatter minima over Stochastic Gradient Descent (SGD). Specifically, we find that SAM efficiently selects flatter minima late in training. Remarkably, even a few epochs of SAM applied at the end of training yield nearly the same generalization and solution sharpness as full SAM training. Subsequently, we delve deeper into the underlying mechanism behind this phenomenon. Theoretically, we identify two phases in the learning dynamics after applying SAM late in training: i) SAM first escapes the minimum found by SGD exponentially fast; and ii) then rapidly converges to a flatter minimum within the same valley. Furthermore, we empirically investigate the role of SAM during the early training phase. We conjecture that the optimization method chosen in the late phase is more crucial in shaping the final solution's properties. Based on this viewpoint, we extend our findings from SAM to Adversarial Training.
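The switching scheme and the two-phase picture described above can be illustrated with a minimal sketch. Everything in it is an assumption made for illustration rather than the paper's setup: the toy loss `0.5 * (1 + u**2) * v**2` (its global minima lie on `v = 0`, and smaller `|u|` is flatter), the step size, the radius `rho`, the starting point, and the use of noise-free gradient descent in place of SGD so that the run is deterministic. The helper names `gd_step`, `sam_step`, and `train` are hypothetical.

```python
import math


def grad(u, v):
    # Gradient of the assumed toy loss L(u, v) = 0.5 * (1 + u^2) * v^2.
    return u * v * v, (1.0 + u * u) * v


def sharpness(u):
    # Largest Hessian eigenvalue of the toy loss at the minimum (u, 0).
    return 1.0 + u * u


def gd_step(u, v, lr):
    # Plain gradient step (stands in for SGD; minibatch noise is omitted).
    gu, gv = grad(u, v)
    return u - lr * gu, v - lr * gv


def sam_step(u, v, lr, rho):
    # SAM step: perturb along the normalized gradient, then descend
    # using the gradient evaluated at the perturbed point.
    gu, gv = grad(u, v)
    norm = math.hypot(gu, gv)
    if norm == 0.0:
        # No ascent direction is defined at an exactly zero gradient.
        return u, v
    pu, pv = u + rho * gu / norm, v + rho * gv / norm
    gu, gv = grad(pu, pv)
    return u - lr * gu, v - lr * gv


def train(total_steps, switch_step, lr=0.1, rho=0.2, start=(2.0, 1.0)):
    # Run gradient descent up to `switch_step`, then SAM until `total_steps`.
    u, v = start
    for t in range(total_steps):
        if t < switch_step:
            u, v = gd_step(u, v, lr)
        else:
            u, v = sam_step(u, v, lr, rho)
    return u, v


u_gd, v_gd = train(200, 200)   # gradient descent only
u_sw, v_sw = train(200, 175)   # switch to SAM for the last 25 steps
print("sharpness, GD only:      ", sharpness(u_gd))
print("sharpness, late-phase SAM:", sharpness(u_sw))
```

In this toy run the late SAM phase first pushes `v` away from zero (the escape) and then keeps shrinking `|u|` while `v` stays close to the line of minima, so the final sharpness is lower than for the run without switching. The sketch only shows the mechanism on a two-parameter problem and says nothing about the magnitudes reported in the paper.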

Paper Structure

This paper contains 23 sections, 8 theorems, 41 equations, 16 figures, 2 tables.

Key Result

Theorem 4.1

Let $\boldsymbol{\theta}^\star$ be a global minimum that is linearly stable for SAM in eq:update_rule_of_USAM, and suppose Assumption ass:noise holds. Then we have $\left\lVert H(\boldsymbol{\theta}^\star)\right\rVert_F^2\left(1+\frac{\rho^2\gamma}{B}\left\lVert H(\boldsymbol{\theta}^\star)\right\rV

Figures (16)

  • Figure 1: SAM operates efficiently even when applied only during the final few epochs of training. (a) Test error curves for WideResNet-16-8 on CIFAR-10 trained with different strategies. The blue and green baseline curves represent training solely with SGD and SAM, respectively. The orange curve shows the test error of a neural network initially trained with SGD up to epoch $t = 175$, followed by SAM training up to epoch $T = 200$. The first 100 epochs are omitted for clarity. Curves are means and standard deviations over five trials with different random seeds. The detailed settings and hyper-parameters are described in \ref{sec:prelim}. (b) A schematic picture of the training trajectory after applying SAM late in training. The contour plot represents the loss landscape, with darker regions indicating lower loss. The orange dotted lines depict the path of the SAM iterator.
  • Figure 2: Illustration of the switching method. Blue dashed lines represent SGD training, while orange dashed lines represent SAM training. $\boldsymbol{\theta}^0$ denotes the random initialization, and $t_1, t_2$ denote two different switching points.
  • Figure 3: The impact of SAM training proportion on generalization and sharpness when switching from SGD to SAM. Top row: the generalization gap between the models $\boldsymbol{\theta}_{\textrm{SGD} \to \textrm{SAM}, t}^T$ and $\boldsymbol{\theta}_{\textrm{SAM}}^T$. Bottom row: the sharpness of $\boldsymbol{\theta}_{\textrm{SGD} \to \textrm{SAM}, t}^T$. Both are plotted against the SAM training proportion of $\boldsymbol{\theta}^T_{\textrm{SGD} \to \textrm{SAM}, t}$. Dots represent the mean over three trials with different random seeds, and error bars indicate standard deviations. Results on more datasets and architectures can be found in \ref{supp:exp_more_late_sam}.
  • Figure 4: (a) Visualization of the two-phase dynamics for \ref{example: toy}. The horizontal gray line represents the set of global minima $\mathcal{M}=\{(u,v):v=0\}$, where smaller values of $|u|$ correspond to flatter minima. The blue lines trace the SGD iterates, leading to $\boldsymbol{\theta}_\text{SGD}^\text{end}$, while the orange lines show the SAM iterates, which converge to a flatter minimum $\boldsymbol{\theta}_\text{SAM}^\text{end}$. Notably, $\boldsymbol{\theta}_\text{SGD}^\text{end}$ and $\boldsymbol{\theta}_\text{SAM}^\text{end}$ stay in the same valley around $\mathcal{M}$. (b) The exponentially fast escape from minima found by SGD. Train loss $\mathcal{L}_{\mathcal{D}_{\textrm{train}}}(\boldsymbol{\theta}_{\textrm{SGD}\to\textrm{SAM}, t}^{t+t'})$ vs. the update step $t'$. Here, $t$ is the switching step, chosen to be sufficiently large for SGD to converge, and $t'$ is the number of updates after switching to SAM. The red line represents the mean over $100$ trials with different randomness. (c) SAM converges to a flatter minimum within the same valley as the one found by SGD. The loss $\mathcal{L}$ (top row) and the second-order finite difference $\mathcal{L}''$ (bottom row) of the interpolated model $\boldsymbol{\theta}_{\lambda}$ vs. the interpolation coefficient $\lambda$. Here, $\boldsymbol{\theta}_{\lambda} = (1-\lambda) \boldsymbol{\theta}_{\textrm{SGD}\to\textrm{SAM}}^{\textrm{end}} + \lambda \boldsymbol{\theta}_{\textrm{SGD}}^{\textrm{end}}$, and $\mathcal{L}''(\boldsymbol{\theta}_\lambda) = \left(\mathcal{L}(\boldsymbol{\theta}_{\lambda+h}) + \mathcal{L}(\boldsymbol{\theta}_{\lambda-h}) - 2\mathcal{L}(\boldsymbol{\theta}_{\lambda})\right)/(2h)^2$, where $h$ is fixed to $0.1$. There is no barrier in the loss curve, indicating that $\boldsymbol{\theta}_{\textrm{SGD}\to\textrm{SAM}}^{\textrm{end}}$ and $\boldsymbol{\theta}_{\textrm{SGD}}^{\textrm{end}}$ stay within the same valley. $\mathcal{L}''(\boldsymbol{\theta}_{\lambda})$ gradually increases with $\lambda$, implying that $\boldsymbol{\theta}_{\textrm{SGD}\to\textrm{SAM}}^{\textrm{end}}$ is flatter than $\boldsymbol{\theta}_{\textrm{SGD}}^{\textrm{end}}$.
  • Figure 5: The impact of SAM training proportion on generalization and sharpness when switching from SAM to SGD. Top row: the generalization gap between the models $\boldsymbol{\theta}_{\textrm{SAM} \to \textrm{SGD}, t}^T$ and $\boldsymbol{\theta}_{\textrm{SAM}}^T$. Bottom row: the sharpness of $\boldsymbol{\theta}_{\textrm{SAM} \to \textrm{SGD}, t}^T$. Both are plotted against the SAM training proportion of $\boldsymbol{\theta}^T_{\textrm{SAM} \to \textrm{SGD}, t}$. Dots represent the mean over three trials with different random seeds, and error bars indicate standard deviations. Results on more datasets and architectures can be found in \ref{supp:exp_more_early_sam}.
  • ...and 11 more figures
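The diagnostic in the Figure 4(c) caption (linear interpolation between two solutions, plus a second-order finite difference of the loss along that path) can be sketched as follows. This is a rough illustration, not the authors' code: the function names, the list-of-floats parameter representation, and the `barrier` helper with its particular definition (highest loss on the path minus the larger endpoint loss) are assumptions. The `(2h)^2` denominator and `h = 0.1` follow the caption as quoted.

```python
def interpolate(theta_a, theta_b, lam):
    # theta_lambda = (1 - lam) * theta_a + lam * theta_b, element-wise.
    return [(1.0 - lam) * a + lam * b for a, b in zip(theta_a, theta_b)]


def second_difference(loss_fn, theta_a, theta_b, lam, h=0.1):
    # Second-order finite difference of the loss along the interpolation
    # path, normalized by (2h)^2 as in the caption.
    up = loss_fn(interpolate(theta_a, theta_b, lam + h))
    down = loss_fn(interpolate(theta_a, theta_b, lam - h))
    mid = loss_fn(interpolate(theta_a, theta_b, lam))
    return (up + down - 2.0 * mid) / (2.0 * h) ** 2


def barrier(loss_fn, theta_a, theta_b, num_points=21):
    # Assumed barrier measure: highest loss on the path minus the larger
    # of the two endpoint losses. Zero means no barrier was detected.
    path = [
        loss_fn(interpolate(theta_a, theta_b, i / (num_points - 1)))
        for i in range(num_points)
    ]
    return max(path) - max(path[0], path[-1])


# Usage on a hypothetical one-parameter quadratic with curvature 4.
quadratic = lambda theta: 0.5 * 4.0 * theta[0] ** 2
print(second_difference(quadratic, [0.0], [1.0], lam=0.5))
print(barrier(quadratic, [0.0], [1.0]))
```

With the `(2h)^2` normalization the quantity equals one quarter of the usual central second difference, which divides by `h^2`. Comparisons across values of $\lambda$ are unaffected by this constant factor.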

Theorems & Definitions (21)

  • Example 4.1
  • Definition 4.1: Loss-based linear stability of SAM
  • Theorem 4.1: P3. Proof in \ref{suppl:proof_linear_stability}.
  • Corollary 4.2: P1. Proof in \ref{suppl:proof_linear_stability}.
  • Definition 4.2: Sub-quadratic landscape for $p=1$
  • Proposition 4.1: P2. Proof in \ref{suppl:proof_p2}.
  • Theorem 4.3: P4. Proof in \ref{suppl:proof_p4}.
  • Remark C.1: Clarifications on $\mathbb{E}$
  • Remark C.2: $\boldsymbol{\xi}^B(\boldsymbol{\theta}^{t+1/2})$ and $\boldsymbol{\theta}^t-\eta \nabla \mathcal{L}(\boldsymbol{\theta}^{t+1/2})$ are uncorrelated
  • Proof: Proof of Proposition \ref{prop: SAM subquadratic same valley}
  • ...and 11 more