Forget Sharpness: Perturbed Forgetting of Model Biases Within SAM Dynamics

Ankit Vani, Frederick Tung, Gabriel L. Oliveira, Hossein Sharifi-Noghabi

TL;DR

This work challenges the prevailing view that SAM's generalization stems from minimizing loss sharpness. It proposes a perturbed forgetting mechanism in which SAM perturbations discard undesirable model biases so that the resulting learning signals generalize better. By linking perturbations to the information bottleneck principle, the authors show that forgetting in the perturbation can reduce both model-specific and input-exposed information, and that it can correlate more strongly with generalization than flatness does. They introduce output bias forgetting (OBF), an alternative perturbation that targets biases revealed through the model's outputs and outperforms standard SAM, GSAM, and ASAM on ImageNet, robustness benchmarks, and transfer to CIFAR-{10,100}, while sometimes converging to sharper regions. The findings suggest that the training dynamics of SAM, rather than flatness of the loss surface, can explain its generalization benefits, with practical implications for designing perturbations that selectively forget biases to improve transfer and robustness in deep models.

Abstract

Although models trained with sharpness-aware minimization (SAM) attain high empirical generalization, their sharpness does not always correlate with generalization error. Instead of viewing SAM as minimizing sharpness to improve generalization, our paper considers a new perspective based on SAM's training dynamics. We propose that perturbations in SAM perform perturbed forgetting, where they discard undesirable model biases to exhibit learning signals that generalize better. We relate our notion of forgetting to the information bottleneck principle, use it to explain observations like the better generalization of smaller perturbation batches, and show that perturbed forgetting can exhibit a stronger correlation with generalization than flatness. While standard SAM targets model biases exposed by the steepest ascent directions, we propose a new perturbation that targets biases exposed through the model's outputs. Our output bias forgetting perturbations outperform standard SAM, GSAM, and ASAM on ImageNet, robustness benchmarks, and transfer to CIFAR-{10,100}, while sometimes converging to sharper regions. Our results suggest that the benefits of SAM can be explained by alternative mechanistic principles that do not require flatness of the loss surface.
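
For readers unfamiliar with the update the abstract refers to, the following is a minimal sketch of one standard SAM step: ascend along the normalized gradient to perturbed weights, compute the gradient there, and apply it at the unperturbed weights. It is not the paper's output bias forgetting perturbation. The toy quadratic loss and the names `grad`, `sam_step`, `rho`, and `lr` are assumptions made for illustration only.

```python
import math


def grad(w, a):
    """Gradient of the hypothetical toy loss L(w) = 0.5 * sum(a_i * w_i**2)."""
    return [ai * wi for ai, wi in zip(a, w)]


def sam_step(w, a, rho=0.1, lr=0.5):
    """One standard SAM update on the toy loss (illustrative sketch only)."""
    g = grad(w, a)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        # Zero gradient: no ascent direction exists, so leave the weights as-is.
        return list(w)
    # Ascent step: perturb the weights by rho along the normalized gradient.
    w_pert = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    # Gradient evaluated at the perturbed weights.
    g_pert = grad(w_pert, a)
    # Descent step applied at the original, unperturbed weights.
    return [wi - lr * gp for wi, gp in zip(w, g_pert)]


w_new = sam_step([1.0, 0.0], [2.0, 1.0])
print(w_new)
```

Under the perturbed forgetting view, the ascent step is where a different perturbation could be substituted, while the final descent step stays the same.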

Paper Structure

This paper contains 43 sections, 15 equations, 3 figures, 5 tables, 1 algorithm.

Figures (3)

  • Figure 1: A simplified illustration of our mechanistic perturbed forgetting perspective of sharpness-aware minimization (SAM). We treat perturbations in each step of SAM as an opportunity to forget undesirable model biases. Here, the presence of 'grassy' or 'sandy' features spuriously contributes to the prediction of 'cow.' While gradient descent (GD) can strengthen these biases, leading to overfitting, the perturbation of SAM takes an ascent step to 'forget' them, which allows computing a less biased gradient. *Not illustrated: this gradient is used to take a GD step at the unperturbed weights.
  • Figure 2: Kendall's $\tau$ correlation of accuracy with sharpness and mutual information metrics averaged over epochs for models trained with different SAM perturbations on CIFAR-10. We train models with perturbation batch size $m\in\{2^k \mid k \in \{0,\ldots,9\}\}$ for each perturbation. Shaded regions indicate the $p$-value estimated with a permutation test, and we show solid lines only when the $p\text{-value} \leq 0.05$.
  • Figure 3: Effect of the hyperparameters $\gamma$ (with fixed $\lambda=1/C$, where $C$ is the number of classes) and $\lambda$ (with fixed $\gamma=1$) on ImageNet top-$1$ accuracy for ViT-S/32 trained using the output bias forgetting (OBF) perturbation in SAM.
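
The rank-correlation analysis described in the Figure 2 caption can be sketched as follows. This is a minimal illustration, assuming Kendall's tau-a (no tie correction) and a two-sided permutation test; the caption does not say which tau variant or test settings the authors used. The function names, `n_perm`, `seed`, and the input values are hypothetical and synthetic, not results from the paper.

```python
import random
from itertools import combinations


def kendall_tau(x, y):
    """Kendall's tau-a: (concordant - discordant pairs) / total pairs."""
    score = 0
    for i, j in combinations(range(len(x)), 2):
        s = (x[i] - x[j]) * (y[i] - y[j])
        if s > 0:
            score += 1  # concordant pair
        elif s < 0:
            score -= 1  # discordant pair (tied pairs contribute 0)
    n_pairs = len(x) * (len(x) - 1) // 2
    return score / n_pairs


def permutation_p_value(x, y, n_perm=2000, seed=0):
    """Two-sided p-value: how often shuffled y gives |tau| >= observed |tau|."""
    rng = random.Random(seed)
    observed = abs(kendall_tau(x, y))
    y_perm = list(y)
    hits = 0
    for _ in range(n_perm):
        rng.shuffle(y_perm)
        if abs(kendall_tau(x, y_perm)) >= observed:
            hits += 1
    # Add-one correction so the estimate is never exactly zero.
    return (hits + 1) / (n_perm + 1)


# Synthetic stand-ins for a per-model metric and accuracy (not paper data).
metric = list(range(10))
accuracy = [v * v for v in metric]
print(kendall_tau(metric, accuracy), permutation_p_value(metric, accuracy))
```

In the figure's setting, each point would correspond to one model trained with a given perturbation batch size $m$, and a correlation would be drawn as a solid line only when the estimated $p$-value is at most 0.05.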