Lookbehind-SAM: k steps back, 1 step forward

Gonçalo Mordido, Pranshu Malviya, Aristide Baratin, Sarath Chandar

TL;DR

Lookbehind-SAM introduces a multistep ascent strategy with linear interpolation to stabilize the maximization/minimization loop in sharpness-aware minimization (SAM/ASAM). By performing $k$ ascent steps to locate higher-loss perturbations and then interpolating between fast and slow weights, it improves the loss-sharpness trade-off and yields stronger generalization, greater robustness to weight noise, and better lifelong learning performance. Empirical results across CNNs and transformers show consistent gains over SAM, ASAM, and multistep variants, with an adaptive $\alpha$ and minibatch switching offering practical benefits. The work also discusses the method's limitations and provides open-source code for replication.

Abstract

Sharpness-aware minimization (SAM) methods have gained increasing popularity by formulating the problem of minimizing both loss value and loss sharpness as a minimax objective. In this work, we increase the efficiency of the maximization and minimization parts of SAM's objective to achieve a better loss-sharpness trade-off. By taking inspiration from the Lookahead optimizer, which uses multiple descent steps ahead, we propose Lookbehind, which performs multiple ascent steps behind to enhance the maximization step of SAM and find a worst-case perturbation with higher loss. Then, to mitigate the variance in the descent step arising from the gathered gradients across the multiple ascent steps, we employ linear interpolation to refine the minimization step. Lookbehind leads to a myriad of benefits across a variety of tasks. Particularly, we show increased generalization performance, greater robustness against noisy weights, as well as improved learning and less catastrophic forgetting in lifelong learning settings. Our code is available at https://github.com/chandar-lab/Lookbehind-SAM.
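
The procedure described in the abstract can be illustrated on a toy problem. The following is a minimal NumPy sketch of one plausible reading of it: $k$ normalized-gradient ascent steps accumulated from the slow weights, a descent step on the fast weights at each perturbed point, and a final linear interpolation. The function name `lookbehind_sam_step`, the quadratic loss, the matrix `A`, the exact update order, and all hyperparameter values are assumptions made for illustration; they are not taken from the paper or from the released code, which remains the reference implementation.

```python
import numpy as np

# Toy quadratic loss L(w) = 0.5 * w^T A w (an assumption for illustration only).
A = np.diag([1.0, 10.0])


def loss(w):
    return 0.5 * w @ A @ w


def grad(w):
    return A @ w


def lookbehind_sam_step(slow, grad_fn, k=2, rho=0.05, eta=0.05, alpha=0.5):
    """One outer step of a Lookbehind-style update (hypothetical sketch).

    slow  : slow weights at the start of the step
    k     : number of inner ascent/descent steps
    rho   : neighborhood size of each ascent step
    eta   : step size for the fast weights
    alpha : interpolation factor between slow and fast weights
    """
    fast = slow.copy()        # fast weights start at the slow weights
    perturbed = slow.copy()   # perturbed point, moved by successive ascent steps
    for _ in range(k):
        g = grad_fn(perturbed)
        norm = np.linalg.norm(g)
        if norm > 0.0:
            # Ascent: move the perturbed point toward higher loss.
            perturbed = perturbed + rho * g / norm
        # Descent: update the fast weights with the gradient at the perturbed point.
        fast = fast - eta * grad_fn(perturbed)
    # Linear interpolation between slow and fast weights.
    return slow + alpha * (fast - slow)


w0 = np.array([1.0, 1.0])
w = w0.copy()
for _ in range(200):
    w = lookbehind_sam_step(w, grad)

print("initial loss:", loss(w0))
print("final loss:  ", loss(w))
```

In this sketch, setting `alpha=0` leaves the slow weights unchanged, while `alpha=1` adopts the fast weights outright; intermediate values damp the combined effect of the $k$ descent steps, which is the role the abstract assigns to linear interpolation.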

Paper Structure

This paper contains 32 sections, 7 equations, 15 figures, 10 tables, and 4 algorithms.

Figures (15)

  • Figure 1: Loss and sharpness trade-off using ResNet-34 trained on CIFAR-10. Darker shades indicate training with larger neighborhood sizes $\rho \in \{0.05, 0.1, 0.2\}$.
  • Figure 2: Illustration of Multistep-SAM (a) and Lookbehind-SAM (b) using $k=2$.
  • Figure 3: $m$-sharpness over multiple radii $r$ using ResNet-34 trained on CIFAR-10. Darker shades indicate training with larger neighborhood sizes: $\rho \in \{0.05, 0.1, 0.2\}$ for SAM and $\rho \in \{0.5, 1.0, 2.0\}$ for ASAM. Lower sharpness is better.
  • Figure 4: Robustness against noisy weights at inference time. We plot the mean and standard deviation over 10 and 3 inference runs for CIFAR-10/100 and ImageNet, respectively. Higher accuracy is better.
  • Figure 5: Generalization performance (validation acc. %) of Multistep-SAM/ASAM, Lookahead-SAM/ASAM, and Lookbehind-SAM/ASAM. Vanilla SAM and ASAM baselines with default $\rho$ are represented by the horizontal line.
  • ...and 10 more figures