Enhancing Sharpness-Aware Minimization by Learning Perturbation Radius

Xuehao Wang, Weisen Jiang, Shuai Fu, Yu Zhang

TL;DR

The paper tackles the sensitivity of sharpness-aware minimization (SAM) to the perturbation radius $\rho$ by introducing LETS, a bilevel framework that learns $\rho$ during training. By modeling $\rho$ as an upper-level variable and the SAM optimization as the lower-level problem, LETS uses a hypergradient with respect to $\rho$ to minimize the squared generalization gap $\frac{1}{2}(\mathcal{L}(\mathcal{D}^{vl}; \boldsymbol{\theta}^*(\rho)) - \mathcal{L}(\mathcal{D}^{tr}; \boldsymbol{\theta}^*(\rho)))^2$, with an alternating gradient-based procedure and a diagonal Hessian approximation for efficiency. LETS is general and can be integrated with any SAM variant, as demonstrated by LETS-ASAM, which replaces the lower-level step with ASAM while keeping the LETS upper-level update. Empirically, LETS-SAM and LETS-ASAM achieve improved accuracy across CIFAR-10/100, ImageNet, IWSLT’14 DE-EN, and GLUE, are robust to label noise and to the initial value of $\rho$, and yield flatter loss landscapes, which is consistent with better generalization. These results suggest a practical path to more reliable SAM-based training in vision and language tasks, reducing the need for manual tuning of $\rho$.
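A minimal sketch of one LETS-style iteration, under stated assumptions: $\boldsymbol{\theta}^*(\rho)$ is approximated by a single unrolled SAM step, the training and validation losses are toy diagonal quadratics (so the diagonal Hessian approximation happens to be exact), and $\rho$ is updated by plain gradient descent and clipped at zero. The helper names (`make_quadratic`, `lets_step`, `train_lets`) and the step sizes `eta` and `beta` are hypothetical; this illustrates the chain rule behind the hypergradient, not the authors' implementation.

```python
import numpy as np


def make_quadratic(diag, b):
    """Toy loss 0.5 * sum(diag * theta^2) - b . theta (hypothetical stand-in)."""
    diag = np.asarray(diag, dtype=float)
    b = np.asarray(b, dtype=float)

    def loss(theta):
        return 0.5 * np.sum(diag * theta**2) - b @ theta

    def grad(theta):
        return diag * theta - b

    def hess_diag(theta):
        # Exact for this toy loss; for a real network this would be an estimate.
        return diag

    return loss, grad, hess_diag


def lets_step(theta, rho, eta, train, val):
    """One lower-level SAM step plus the upper-level hypergradient w.r.t. rho.

    Assumption: theta*(rho) is approximated by a single unrolled SAM step.
    Returns (theta_new, F, dF/drho) with F = 0.5 * (L_vl - L_tr)^2.
    """
    loss_tr, grad_tr, hdiag_tr = train
    loss_vl, grad_vl, _ = val

    # Lower level (SAM): perturb along the unit gradient, then descend.
    g = grad_tr(theta)
    u = g / (np.linalg.norm(g) + 1e-12)   # unit direction, independent of rho
    w = theta + rho * u                   # perturbed weights theta + eps
    theta_new = theta - eta * grad_tr(w)

    # Chain rule: d theta_new / d rho = -eta * H(w) u, with H replaced by its diagonal.
    dtheta_drho = -eta * hdiag_tr(w) * u

    # Upper level: squared gap between validation and training losses.
    gap = loss_vl(theta_new) - loss_tr(theta_new)
    F = 0.5 * gap**2
    dF_drho = gap * ((grad_vl(theta_new) - grad_tr(theta_new)) @ dtheta_drho)
    return theta_new, F, dF_drho


def train_lets(theta, rho, eta, beta, steps, train, val):
    """Alternate SAM updates of theta with gradient updates of rho."""
    for _ in range(steps):
        theta, _, dF_drho = lets_step(theta, rho, eta, train, val)
        rho = max(rho - beta * dF_drho, 0.0)  # keep the radius non-negative
    return theta, rho


if __name__ == "__main__":
    train = make_quadratic([1.0, 4.0], [1.0, 1.0])
    val = make_quadratic([1.5, 3.0], [0.8, 1.2])
    theta, rho = train_lets(np.zeros(2), 0.05, 0.1, 0.5, 200, train, val)
    print("learned rho:", rho)
```

Because the unit direction $\mathbf{u}_t$ depends on $\boldsymbol{\theta}_t$ but not on $\rho$, differentiating one SAM step gives $d\boldsymbol{\theta}_{t+1}/d\rho = -\eta\,\mathbf{H}(\boldsymbol{\theta}_t+\boldsymbol{\epsilon}_t)\,\mathbf{u}_t$; with a diagonal estimate of $\mathbf{H}$, this extra term reduces to an elementwise product.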

Abstract

Sharpness-aware minimization (SAM) aims to improve model generalization by searching for flat minima in the loss landscape. Each SAM update consists of two steps: one computes the perturbation and the other computes the update gradient. In this two-step procedure, the choice of the perturbation radius is crucial to the performance of SAM, but finding an appropriate perturbation radius is challenging. In this paper, we propose a bilevel optimization framework called LEarning the perTurbation radiuS (LETS) to learn the perturbation radius for sharpness-aware minimization algorithms. Specifically, in the proposed LETS method, the upper-level problem seeks a good perturbation radius by minimizing the squared generalization gap between the training and validation losses, while the lower-level problem is the SAM optimization problem. Moreover, LETS can be combined with any variant of SAM. Experimental results on various architectures and benchmark datasets in computer vision and natural language processing demonstrate the effectiveness of LETS in improving the performance of SAM.
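The bilevel problem described in the abstract can be written as follows. This is a sketch in the standard SAM form (an $\ell_2$ ball of radius $\rho$ and the usual first-order perturbation); the paper's exact notation may differ. The first display is the bilevel problem, and the second is the two-step SAM update with step size $\eta$ that approximately solves the lower level.

```latex
\begin{aligned}
\min_{\rho \ge 0}\;& \tfrac{1}{2}\Big(\mathcal{L}(\mathcal{D}^{vl}; \boldsymbol{\theta}^*(\rho)) - \mathcal{L}(\mathcal{D}^{tr}; \boldsymbol{\theta}^*(\rho))\Big)^2 \\
\text{s.t.}\;& \boldsymbol{\theta}^*(\rho) \in \arg\min_{\boldsymbol{\theta}} \max_{\|\boldsymbol{\epsilon}\|_2 \le \rho} \mathcal{L}(\mathcal{D}^{tr}; \boldsymbol{\theta} + \boldsymbol{\epsilon})
\end{aligned}

\boldsymbol{\epsilon}_t = \rho\,\frac{\nabla \mathcal{L}(\mathcal{D}^{tr}; \boldsymbol{\theta}_t)}{\|\nabla \mathcal{L}(\mathcal{D}^{tr}; \boldsymbol{\theta}_t)\|_2},
\qquad
\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta\,\nabla \mathcal{L}(\mathcal{D}^{tr}; \boldsymbol{\theta}_t + \boldsymbol{\epsilon}_t)
```

The radius $\rho$ enters both steps: it scales the perturbation $\boldsymbol{\epsilon}_t$, and the update gradient is evaluated at the perturbed point, which is why the upper-level objective depends on $\rho$ through $\boldsymbol{\theta}^*(\rho)$.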

Paper Structure (28 sections, 1 theorem, 22 equations, 9 figures, 12 tables, 1 algorithm)

Key Result

Theorem 4

Let $b$ be the mini-batch size. If the step size is $\eta=\frac{1}{\gamma\sqrt{T}}$ and $\rho_t\leq \frac{\kappa}{\sqrt{T}}$, where $\kappa>0$ is a constant, then Algorithm LETS-SAM satisfies [inequality missing], where the expectation is taken over the random training samples.
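The two conditions in Theorem 4 fix the step size and cap the radius as functions of the horizon $T$. A minimal sketch of computing them, with hypothetical helper names (`lets_sam_schedule`, `clip_radius`); clipping a learned $\rho_t$ into $[0, \kappa/\sqrt{T}]$ is one assumed way to satisfy the condition in practice and is not necessarily what LETS-SAM does.

```python
import math


def lets_sam_schedule(gamma, kappa, T):
    """Return (eta, rho_max) with eta = 1/(gamma*sqrt(T)) and rho_max = kappa/sqrt(T).

    gamma and kappa are the constants appearing in Theorem 4; T is the number
    of iterations.
    """
    if gamma <= 0 or kappa <= 0 or T <= 0:
        raise ValueError("gamma, kappa and T must be positive")
    eta = 1.0 / (gamma * math.sqrt(T))
    rho_max = kappa / math.sqrt(T)
    return eta, rho_max


def clip_radius(rho_t, rho_max):
    """Project a learned radius into [0, rho_max] (assumed enforcement of rho_t <= kappa/sqrt(T))."""
    return min(max(rho_t, 0.0), rho_max)
```

For example, with $\gamma = 2$, $\kappa = 0.5$ and $T = 100$, both $\eta$ and the radius cap equal $0.05$; quadrupling $T$ halves both, so the theorem's regime shrinks the step size and the radius together at rate $1/\sqrt{T}$.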

Figures (9)

  • Figure 1: Classification accuracy of SAM and ASAM with different values of $\rho$ on MRPC using DeBERTa. Performance is sensitive to the choice of $\rho$.
  • Figure 2: Generalization gap w.r.t. training epochs on CIFAR-100. Best viewed in color.
  • Figure 3: Experimental results on IWSLT’14 DE-EN.
  • Figure 4: Testing accuracy on five datasets from GLUE.
  • Figure 5: Loss landscapes of different methods built on ResNet-18 for CIFAR-10, where the x- and y-axes represent two orthogonal weight perturbations and the z-axis represents the loss value.
  • ...and 4 more figures
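Figure 5 describes each landscape as the loss over two orthogonal weight perturbations. A minimal sketch of how such a surface can be computed, assuming two random directions orthonormalized by Gram-Schmidt and a generic `loss_fn` over a flat weight vector; the function names, grid radius, and resolution are hypothetical, and the paper's plotting procedure (for example, any per-filter normalization of the directions) may differ.

```python
import numpy as np


def orthonormal_directions(dim, rng):
    """Two random orthonormal directions in weight space (Gram-Schmidt)."""
    d1 = rng.standard_normal(dim)
    d1 /= np.linalg.norm(d1)
    d2 = rng.standard_normal(dim)
    d2 -= (d2 @ d1) * d1  # remove the component along d1
    d2 /= np.linalg.norm(d2)
    return d1, d2


def loss_surface(loss_fn, theta, d1, d2, radius=1.0, num=21):
    """Evaluate loss_fn(theta + a*d1 + b*d2) on a num x num grid over [-radius, radius]^2."""
    alphas = np.linspace(-radius, radius, num)  # x-axis coefficients
    betas = np.linspace(-radius, radius, num)   # y-axis coefficients
    Z = np.empty((num, num))                    # z-axis: loss values
    for i, a in enumerate(alphas):
        for j, b in enumerate(betas):
            Z[i, j] = loss_fn(theta + a * d1 + b * d2)
    return alphas, betas, Z
```

A flatter minimum shows up as a surface `Z` that rises more slowly away from the center of the grid, which is the visual comparison the figure makes between methods.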

Theorems & Definitions (2)

  • Theorem 4
  • Proof of Theorem 4