Momentum-SAM: Sharpness Aware Minimization without Computational Overhead

Marlon Becker, Frederick Altrock, Benjamin Risse

TL;DR

This paper proposes Momentum-SAM (MSAM), which perturbs parameters in the direction of the accumulated momentum vector to achieve low sharpness without significant computational overhead or memory demands over SGD or Adam.

Abstract

Sharpness Aware Minimization (SAM), a recently proposed optimization algorithm for deep neural networks, perturbs parameters by a gradient ascent step before the gradient calculation to guide the optimization into regions of parameter space with flat loss. While significant generalization improvements, and thus reduced overfitting, have been demonstrated, the computational cost is doubled by the additional gradient calculation, making SAM infeasible when computational capacity is limited. Motivated by Nesterov Accelerated Gradient (NAG), we propose Momentum-SAM (MSAM), which perturbs parameters in the direction of the accumulated momentum vector to achieve low sharpness without significant computational overhead or memory demands over SGD or Adam. We evaluate MSAM in detail and reveal insights on separable mechanisms of NAG, SAM, and MSAM regarding training optimization and generalization. Code is available at https://github.com/MarlonBecker/MSAM.
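
The cost difference described in the abstract can be illustrated with a minimal sketch. It is not the authors' algorithm or their repository code: the toy quadratic loss, the helper names (`grad`, `unit`, `momentum_update`, `sam_step`, `msam_step`, `run`), the hyperparameter values, the sign of `rho`, and the normalization of the perturbation are all assumptions made for illustration. The sketch only shows that a SAM-style step needs two gradient evaluations, whereas a momentum-based perturbation reuses the stored momentum vector and needs one.

```python
import math

# Hypothetical toy loss L(w) = 0.5 * sum(a_i * w_i**2); not from the paper.
CURVATURE = [10.0, 1.0]

def loss(w):
    return 0.5 * sum(a * wi * wi for a, wi in zip(CURVATURE, w))

def grad(w, counter):
    counter[0] += 1  # count gradient evaluations to compare cost
    return [a * wi for a, wi in zip(CURVATURE, w)]

def unit(x):
    # Normalize to unit length; return zeros for a zero vector.
    n = math.sqrt(sum(xi * xi for xi in x))
    return [xi / n for xi in x] if n > 0 else [0.0] * len(x)

def momentum_update(w, v, g, lr, mu):
    # Plain SGD with momentum; v accumulates gradients (assumed convention).
    v = [mu * vi + gi for vi, gi in zip(v, g)]
    w = [wi - lr * vi for wi, vi in zip(w, v)]
    return w, v

def sam_step(w, v, lr, mu, rho, counter):
    # SAM-style: one gradient call for the ascent direction,
    # a second one at the perturbed point.
    d = unit(grad(w, counter))
    w_pert = [wi + rho * di for wi, di in zip(w, d)]
    g = grad(w_pert, counter)
    return momentum_update(w, v, g, lr, mu)  # perturbation is not kept in w

def msam_step(w, v, lr, mu, rho, counter):
    # MSAM-style: perturb along the stored momentum vector, so only one
    # gradient call is needed. Sign and scaling of rho are assumptions.
    d = unit(v)
    w_pert = [wi + rho * di for wi, di in zip(w, d)]
    g = grad(w_pert, counter)
    return momentum_update(w, v, g, lr, mu)

def run(step, n_steps=100, lr=0.05, mu=0.9, rho=0.01):
    w, v, counter = [1.0, 1.0], [0.0, 0.0], [0]
    for _ in range(n_steps):
        w, v = step(w, v, lr, mu, rho, counter)
    return loss(w), counter[0]

if __name__ == "__main__":
    print("SAM-style  (loss, gradient calls):", run(sam_step))
    print("MSAM-style (loss, gradient calls):", run(msam_step))
```

With the default 100 steps, the SAM-style loop performs 200 gradient evaluations and the MSAM-style loop performs 100. The toy problem says nothing about generalization; it only makes the per-step cost visible.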

Paper Structure (39 sections, 3 theorems, 12 equations, 21 figures, 7 tables, 1 algorithm)

This paper contains 39 sections, 3 theorems, 12 equations, 21 figures, 7 tables, 1 algorithm.

Key Result

Proposition 1

Let $\boldsymbol{\epsilon},\boldsymbol{v} \in \mathcal{W}$ with i.i.d. components $\boldsymbol{\epsilon}_i \sim\mathcal{N}(0,\sigma)$ for some $\sigma >0$, then for any $\rho >0$ where $\kappa := \frac{1}{|\mathcal{W}|}\mathrm{tr}[\mathrm{Hess}(L_{\mathcal{S}}(\boldsymbol{w}))]$.

Figures (21)

  • Figure 1: Schematic illustrations of optimization algorithms based on SGD. NAG calculates gradients after updating parameters with the momentum vector. SAM and MSAM calculate gradients at perturbed positions but remove perturbations again before the parameter update step. See Alg. \ref{alg:MSAM} for a detailed description of the efficient implementation of MSAM.
  • Figure 2: WideResNet-16-4 on CIFAR100. A: Test accuracy for positive and negative $\rho$ compared against SGD and NAG. B: Train and test accuracy on a logarithmic scale. C: Cosine similarity between the momentum vector $\boldsymbol{v}_{t-1}$ and the gradient $g_t = \nabla L_{\mathcal{B}_t}(\boldsymbol{w}_t)$. The momentum vector direction mostly has negative slope during training, and the slope approaches zero at the end (caused by the cosine learning rate scheduler).
  • Figure 3: A: Projecting $g_\text{MSAM}$ into the plane of $g_\text{SAM}$ and $g_\text{SGD}$ to measure SAM/MSAM similarity. B: Varying $\rho_\text{MSAM}$ until maximal similarity is reached (i.e. $\theta = 0$) to determine $\rho_0$. C: The value $\rho_0$ of $\rho_\text{MSAM}$ at maximal similarity is close to the generalization optimum ($\rho_\text{MSAM}^{\text{opt}}$, cf. Fig. \ref{fig:rho_search}A) for most epochs.
  • Figure 4: Sharpness (Eq. \ref{eq:sharpness}) after full training (200 epochs) for WideResNet-16-4 on CIFAR100 along different directions $\boldsymbol{\epsilon}$ scaled by $\rho$. A: Gradient direction (as used for the perturbation in SAM). B: Momentum direction (as in MSAM). C: Random filter-normalized direction as in \cite{visualloss}. The vertical line at $\rho^\text{opt}$ marks the value of optimal test performance (cf. Fig. \ref{fig:full_rho_scan_cifar}A). MSAM and SAM each reduce their respective optimization objective best, while MSAM reaches the lowest sharpness along random directions.
  • Figure A.1: Test (A) and train (B) accuracy for WideResNet-16-4 on CIFAR100 for different normalization schemes of MSAM as a function of $\rho$. MSAM without normalization works equally well. If the perturbation $\epsilon$ is scaled by the learning rate $\eta$, train performance (optimization) increases while test performance (generalization) benefits only marginally.
  • ...and 16 more figures
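
Two quantities recur in the captions above: the cosine similarity between the momentum vector and the gradient (Figure 2C) and the sharpness along a chosen direction scaled by $\rho$ (Figure 4). The following is a minimal sketch of how such measurements could be computed. The function names are hypothetical, and the sharpness definition used here, $L(\boldsymbol{w} + \rho\,\boldsymbol{d}/\|\boldsymbol{d}\|) - L(\boldsymbol{w})$, is an assumed common form, since the paper's sharpness equation is not reproduced in this summary.

```python
import math

def cosine_similarity(a, b):
    # Standard cosine similarity; returns 0.0 if either vector is zero.
    dot = sum(ai * bi for ai, bi in zip(a, b))
    na = math.sqrt(sum(ai * ai for ai in a))
    nb = math.sqrt(sum(bi * bi for bi in b))
    if na == 0 or nb == 0:
        return 0.0
    return dot / (na * nb)

def directional_sharpness(loss_fn, w, direction, rho):
    # Assumed definition: loss increase after moving a distance rho along
    # the normalized direction. The paper's own equation may differ.
    n = math.sqrt(sum(di * di for di in direction))
    if n == 0:
        return 0.0
    w_pert = [wi + rho * di / n for wi, di in zip(w, direction)]
    return loss_fn(w_pert) - loss_fn(w)

if __name__ == "__main__":
    # Hypothetical toy loss, used only to exercise the helpers.
    toy_loss = lambda w: 0.5 * sum(wi * wi for wi in w)
    momentum = [1.0, 2.0]
    gradient = [-1.0, -2.0]
    print(cosine_similarity(momentum, gradient))
    print(directional_sharpness(toy_loss, [0.0, 0.0], [3.0, 4.0], 0.1))
```

A negative cosine similarity between momentum and gradient means the two vectors point in roughly opposite directions, which is the situation Figure 2C reports for most of training.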

Theorems & Definitions (3)

  • Proposition 1
  • Lemma 2
  • Theorem 3