Stabilizing Sharpness-aware Minimization Through A Simple Renormalization Strategy

Chengli Tan, Jiangshe Zhang, Junmin Liu, Yicheng Wang, Yunda Hao

TL;DR

The paper tackles the instability of sharpness-aware minimization (SAM) at larger learning rates by introducing Stable SAM (SSAM), a simple renormalization of the descent gradient to match the ascent gradient’s magnitude via a factor γ_t. The authors provide a theoretical framework based on uniform stability and convergence analysis to show that SAM’s benefits are restricted to a narrow learning-rate regime, while SSAM extends this regime and yields improved generalization with only minor computational overhead. They validate the theory with extensive experiments across stability metrics, convergence on quadratic losses, and large-scale vision tasks, demonstrating that SSAM often outperforms SAM and, in many cases, SGD, while finding flatter minima as evidenced by Hessian analyses. Overall, SSAM offers a robust, plug-in enhancement to sharpness-aware optimization that broadens stable training and improves generalization in practical deep learning settings.

Abstract

Recently, sharpness-aware minimization (SAM) has attracted much attention because of its surprising effectiveness in improving generalization performance. However, compared to stochastic gradient descent (SGD), it is more prone to getting stuck at saddle points, which may in turn degrade performance. To address this issue, we propose a simple renormalization strategy, dubbed Stable SAM (SSAM), so that the gradient norm of the descent step remains the same as that of the ascent step. Our strategy is easy to implement and flexible enough to integrate with SAM and its variants at almost no computational cost. With elementary tools from convex optimization and learning theory, we also conduct a theoretical analysis of sharpness-aware training, revealing that, compared to SGD, the effectiveness of SAM is only assured in a limited regime of learning rates. In contrast, we show how SSAM extends this regime of learning rates and thus, with this minor modification, can consistently perform better than SAM. Finally, we demonstrate the improved performance of SSAM on several representative data sets and tasks.
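
A minimal sketch of one sharpness-aware step follows, assuming the standard SAM ascent step $\boldsymbol{w}^{asc}_t = \boldsymbol{w}_t + \rho\, \nabla F_{\Omega_t}(\boldsymbol{w}_t)/\|\nabla F_{\Omega_t}(\boldsymbol{w}_t)\|_2$ and plain gradient descent without momentum or weight decay. With `renormalize=True`, the descent gradient is multiplied by $\gamma_t = \|\nabla F_{\Omega_t}(\boldsymbol{w}_t)\|_2 / \|\nabla F_{\Omega_t}(\boldsymbol{w}^{asc}_t)\|_2$ (the ratio plotted in Figure 3), so the update takes the norm of the ascent gradient. The function names, the small `eps` guard against division by zero, and the quadratic test objective are illustrative assumptions; this is not a transcription of the paper's Algorithm 1.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]
GradFn = Callable[[Vector], Vector]


def l2_norm(v: Vector) -> float:
    """Euclidean (L2) norm of a vector."""
    return math.sqrt(sum(x * x for x in v))


def sharpness_aware_step(
    w: Vector,
    grad: GradFn,
    lr: float,
    rho: float,
    renormalize: bool = False,
    eps: float = 1e-12,
) -> Tuple[Vector, float]:
    """One SAM step (renormalize=False) or SSAM-style step (renormalize=True).

    `grad` stands in for the mini-batch gradient of F_{Omega_t}; the same
    mini-batch is used at w_t and at the perturbed point. Returns the new
    weights and the factor applied to the descent gradient (gamma_t for SSAM).
    """
    g = grad(w)
    g_norm = l2_norm(g)
    # Ascent step: perturb w_t along the normalized gradient with radius rho.
    w_asc = [wi + rho * gi / (g_norm + eps) for wi, gi in zip(w, g)]
    g_asc = grad(w_asc)
    gamma = 1.0
    if renormalize:
        # gamma_t = ||grad F(w_t)|| / ||grad F(w_t^asc)||: rescale the descent
        # gradient so that its norm equals the norm of the ascent gradient.
        gamma = g_norm / (l2_norm(g_asc) + eps)
    w_new = [wi - lr * gamma * gi for wi, gi in zip(w, g_asc)]
    return w_new, gamma


# Usage: a convex quadratic F(w) = 0.5 * (w_0^2 + 10 * w_1^2).
def quad_grad(w: Vector) -> Vector:
    return [1.0 * w[0], 10.0 * w[1]]


w0 = [1.0, 2.0]
w_sam, _ = sharpness_aware_step(w0, quad_grad, lr=0.05, rho=0.05)
w_ssam, gamma = sharpness_aware_step(w0, quad_grad, lr=0.05, rho=0.05, renormalize=True)
print(f"gamma_t = {gamma:.4f}")
```

On a convex quadratic $F(\boldsymbol{w}) = \tfrac{1}{2}\boldsymbol{w}^\top H \boldsymbol{w}$ with $H \succeq 0$, writing $\boldsymbol{g} = H\boldsymbol{w}$ gives $\|\nabla F(\boldsymbol{w}^{asc})\|_2^2 = \|\boldsymbol{g}\|_2^2 + 2\rho\, \boldsymbol{g}^\top H \boldsymbol{g}/\|\boldsymbol{g}\|_2 + \rho^2 \|H\boldsymbol{g}\|_2^2/\|\boldsymbol{g}\|_2^2 \ge \|\boldsymbol{g}\|_2^2$, so $\gamma_t \le 1$ there and the renormalized step is never longer than the SAM step.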

Paper Structure (19 sections, 15 theorems, 63 equations, 10 figures, 2 tables, 1 algorithm)

Key Result

Theorem 1

Let $S$ and $S^\prime$ denote two training sets sampled i.i.d. from the same data distribution $\mathfrak{D}$ such that $S$ and $S^\prime$ differ in at most one example. A learning algorithm $\mathcal{A}$ is $\varepsilon$-uniformly stable if and only if, for all such samples $S$ and $S^\prime$, the following holds: … Furthermore, if $\mathcal{A}$ is $\varepsilon$-uniformly stable, the expected generalization error …
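
The statement of Theorem 1 is truncated here. For reference, the standard definition and bound from Hardt, Recht, and Singer (2016), on which results of this kind are usually built, are given below for a loss $f(\boldsymbol{w}; z)$, empirical risk $R_S[\boldsymbol{w}] = \frac{1}{n}\sum_{i=1}^{n} f(\boldsymbol{w}; z_i)$, and population risk $R[\boldsymbol{w}] = \mathbb{E}_{z\sim\mathfrak{D}}\, f(\boldsymbol{w}; z)$. This notation is assumed, and the paper's exact statement may differ.

```latex
% Uniform stability: for all S, S' differing in at most one example,
\sup_{z}\; \mathbb{E}_{\mathcal{A}}\!\left[ f(\mathcal{A}(S); z) - f(\mathcal{A}(S^\prime); z) \right] \le \varepsilon .
% Consequence: the expected generalization error is bounded by the stability constant,
\left| \mathbb{E}_{S,\mathcal{A}}\!\left[ R_S[\mathcal{A}(S)] - R[\mathcal{A}(S)] \right] \right| \le \varepsilon .
```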

Figures (10)

  • Figure 1: Loss curves of different optimizers escaping from the saddle point (namely, the origin) under different values of $\rho$. Following compagnoni2023sde, we approximate the identity matrix of dimension $d=20$ by the product of two square matrices (a linear autoencoder) and initialize their entries with samples from $\mathcal{N}(0, 10^{-4})$. We then train the linear autoencoder with different optimizers for up to 500 epochs using a constant learning rate of $10^{-3}$.
  • Figure 2: (a) Contour plot of the function $f(x_1, x_2) = x_1^4/4 - x_1 x_2 + x_2^2/2$; the symbol $+$ marks the two global minima at $(-1, -1)$ and $(1, 1)$. (b)-(d) show the rate of successful training as a function of the learning rate for different optimizers and perturbation radii $\rho$. The SGD curve is identical across these subplots because SGD does not depend on $\rho$.
  • Figure 3: Evolution of the ratio $\gamma_t$ of the gradient norm of the ascent step, $\|\nabla F_{\Omega_t}(\boldsymbol{w}_t)\|_2$, to that of the descent step, $\|\nabla F_{\Omega_t}(\boldsymbol{w}^{asc}_t)\|_2$, throughout training. Both neural networks are trained up to 200 epochs using the SAM optimizer with perturbation radii $\rho\in \{0.01, 0.05, 0.2\}$.
  • Figure 4: Evolution of (a) parameter distance and (b) generalization gap as a function of epoch. The base model is a fully connected neural network and the data set is MNIST. All models are trained with a constant learning rate and neither momentum nor weight decay is employed.
  • Figure 5: Evolution of (a) parameter distance and (b) generalization gap as a function of epoch. The base model is LeNet and the data set is CIFAR-10. All models are trained with a constant learning rate and neither momentum nor weight decay is employed.
  • ...and 5 more figures
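
The toy problem of Figure 2 can be explored with a short script. Setting the partial derivatives $x_1^3 - x_2$ and $x_2 - x_1$ to zero gives stationary points at $(0, 0)$, $(1, 1)$, and $(-1, -1)$; the origin is a saddle (its Hessian has determinant $-1$), while the other two are the global minima with $f = -1/4$. The sketch below runs full-batch (noise-free) gradient descent labelled "sgd", SAM, and an SSAM-style step (descent gradient rescaled by the ratio of ascent to descent gradient norms, as defined in the Figure 3 caption) from random starts. The success criterion (ending within 0.1 of a minimum), the initialization box $[-2, 2]^2$, the step budget, and all function names are hypothetical choices, not the paper's protocol, so the printed rates need not match the figure.

```python
import math
import random
from typing import List

Point = List[float]


def f(x: Point) -> float:
    """Toy objective from Figure 2: x1^4/4 - x1*x2 + x2^2/2."""
    x1, x2 = x
    return x1 ** 4 / 4 - x1 * x2 + x2 ** 2 / 2


def grad_f(x: Point) -> Point:
    """Analytic gradient: (x1^3 - x2, x2 - x1)."""
    x1, x2 = x
    return [x1 ** 3 - x2, x2 - x1]


def step(x: Point, lr: float, rho: float, method: str, eps: float = 1e-12) -> Point:
    """One full-batch update with 'sgd', 'sam', or 'ssam' (renormalized SAM)."""
    g = grad_f(x)
    if method == "sgd":
        return [xi - lr * gi for xi, gi in zip(x, g)]
    g_norm = math.hypot(*g)
    # Ascent step of radius rho along the normalized gradient.
    x_asc = [xi + rho * gi / (g_norm + eps) for xi, gi in zip(x, g)]
    g_asc = grad_f(x_asc)
    # SSAM-style factor gamma_t; plain SAM uses the descent gradient as is.
    gamma = g_norm / (math.hypot(*g_asc) + eps) if method == "ssam" else 1.0
    return [xi - lr * gamma * gi for xi, gi in zip(x, g_asc)]


def success_rate(method: str, lr: float, rho: float, trials: int = 30,
                 steps: int = 2000, tol: float = 0.1, seed: int = 0) -> float:
    """Fraction of random starts in [-2, 2]^2 that end within `tol` of a minimum.

    Hypothetical protocol for illustration; the paper's criterion may differ.
    """
    rng = random.Random(seed)
    minima = ([1.0, 1.0], [-1.0, -1.0])
    hits = 0
    for _ in range(trials):
        x = [rng.uniform(-2.0, 2.0), rng.uniform(-2.0, 2.0)]
        for _ in range(steps):
            x = step(x, lr, rho, method)
            if max(abs(v) for v in x) > 1e6:  # treat as divergence
                break
        if any(math.dist(x, m) < tol for m in minima):
            hits += 1
    return hits / trials


if __name__ == "__main__":
    for lr in (0.05, 0.5):
        rates = {m: success_rate(m, lr, rho=0.05) for m in ("sgd", "sam", "ssam")}
        print(f"lr={lr}: {rates}")
```

Sweeping `lr` over a finer grid and `rho` over several values gives curves in the spirit of panels (b)-(d).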

Theorems & Definitions (20)

  • Theorem 1: Generalization error under $\varepsilon$-uniform stability
  • Lemma 2
  • Remark 3
  • Lemma 4
  • Theorem 5
  • Corollary 6
  • Lemma 7
  • Theorem 8
  • Remark 9
  • Theorem 10
  • ...and 10 more