
Adaptively Sampling-Reusing-Mixing Decomposed Gradients to Speed Up Sharpness Aware Minimization

Jiaxin Deng, Junbiao Pang

TL;DR

Sharpness-Aware Minimization (SAM) improves generalization but doubles the cost of SGD because it needs two gradient passes per step. ARSAM builds on the observation that the SAM gradient decomposes into the SGD gradient and the projection of the second-order gradient onto the first-order gradient (PSF), whose influence grows during training. It then introduces adaptive sampling and PSF reuse via an autoregressive schedule to cut cost while preserving flatness-driven generalization. The method is supported by a convergence analysis and by extensive experiments showing a speedup of about 40% with accuracy on par with SAM across CNNs, as well as applicability to segmentation, pose estimation, and quantization. This offers a practical, broadly applicable way to accelerate sharpness-aware optimization in vision tasks.

Abstract

Sharpness-Aware Minimization (SAM) improves model generalization but doubles the computational cost of Stochastic Gradient Descent (SGD) by requiring two gradient computations per optimization step. To mitigate this, we propose Adaptively sampling-Reusing-mixing decomposed gradients to significantly accelerate SAM (ARSAM). Concretely, we first discover that SAM's gradient can be decomposed into the SGD gradient and the Projection of the Second-order gradient onto the First-order gradient (PSF). Furthermore, we observe that the SGD gradient and the PSF evolve dynamically during training, with the PSF playing a growing role in reaching flat minima. Therefore, we propose ARSAM, which reuses the PSF and updates it in a timely manner; together, the reused PSF and the timely updated PSF still maintain the model's generalization ability. Extensive experiments show that ARSAM achieves state-of-the-art accuracies comparable to SAM across diverse network architectures. On CIFAR-10/100, ARSAM achieves accuracy comparable to SAM while providing a speedup of about 40%. Moreover, ARSAM accelerates optimization for various challenging tasks (e.g., human pose estimation and model quantization) without sacrificing performance, demonstrating its broad practicality. The code is publicly accessible at: https://github.com/ajiaaa/ARSAM.
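The decomposition described in the abstract can be illustrated on a toy problem. Because the abstract does not give the formula, the sketch below assumes that the PSF is the Hessian term $\rho \nabla^2 L(\mathbf{w}) \nabla L(\mathbf{w}) / \|\nabla L(\mathbf{w})\|$ obtained from a first-order Taylor expansion of the SAM gradient $\nabla L(\mathbf{w} + \rho \nabla L(\mathbf{w}) / \|\nabla L(\mathbf{w})\|)$. On a quadratic loss $L(\mathbf{w}) = \frac{1}{2}\mathbf{w}^\top A \mathbf{w}$ (our choice, as are the matrix `A`, the radius `rho`, and all function names) that expansion is exact, so the SGD gradient plus the PSF reproduces the two-pass SAM gradient.

```python
import math

# Toy quadratic loss L(w) = 0.5 * w^T A w with a symmetric A (illustrative assumption).
# Its gradient is A w and its Hessian is A everywhere, so the first-order Taylor
# expansion used to split the SAM gradient is exact for this loss.
A = [[3.0, 1.0],
     [1.0, 2.0]]


def matvec(M, v):
    """Matrix-vector product for lists of lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def norm(v):
    """Euclidean (L2) norm."""
    return math.sqrt(sum(x * x for x in v))


def grad(w):
    """First-order (SGD) gradient of the toy loss."""
    return matvec(A, w)


def sam_gradient(w, rho):
    """Two-pass SAM gradient: the gradient at the ascent point w + rho * g / ||g||."""
    g = grad(w)                                   # first pass
    g_norm = norm(g)
    w_adv = [wi + rho * gi / g_norm for wi, gi in zip(w, g)]
    return grad(w_adv)                            # second pass


def decompose(w, rho):
    """Return (SGD gradient, PSF), assuming PSF = rho * H g / ||g|| with H = A."""
    g = grad(w)
    g_norm = norm(g)
    psf = [rho * hg / g_norm for hg in matvec(A, g)]  # Hessian-vector product along g
    return g, psf


if __name__ == "__main__":
    w, rho = [1.0, -0.5], 0.05
    g_sgd, psf = decompose(w, rho)
    print("SAM gradient:", sam_gradient(w, rho))
    print("SGD + PSF:   ", [a + b for a, b in zip(g_sgd, psf)])
```

In a real network the expansion is only approximate, so the two-pass SAM gradient and the sum of the two terms agree only to first order in $\rho$.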

Paper Structure

This paper contains 24 sections, 3 theorems, 20 equations, 2 figures, 8 tables, and 1 algorithm.

Key Result

Theorem 1

Let $\nabla_\mathbf{w}^2 L(\mathbf{w})$ be the Hessian matrix with $n$ eigenvalues. Then the L2-PSF has an upper bound as follows: where $\delta_i$ is the $i$-th eigenvalue.
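The bound in Theorem 1 is stated in terms of the Hessian eigenvalues $\delta_i$, but its exact expression is not reproduced in this summary. As a hedged illustration only, the sketch below checks a standard linear-algebra fact that a bound of this kind can build on: if the PSF takes the assumed form $\rho H \mathbf{g} / \|\mathbf{g}\|$ with a symmetric Hessian $H$, its L2 norm is at most $\rho \max_i |\delta_i|$, because a symmetric matrix stretches a unit vector by at most its largest absolute eigenvalue. The 2x2 Hessian, `rho`, and the helper names are assumptions; this is not claimed to be the form of Theorem 1.

```python
import math
import random


def sym2_eigenvalues(a, b, c):
    """Closed-form eigenvalues of the symmetric 2x2 matrix [[a, b], [b, c]]."""
    mean = (a + c) / 2.0
    radius = math.sqrt(((a - c) / 2.0) ** 2 + b ** 2)
    return mean - radius, mean + radius


def psf_norm(a, b, c, g, rho):
    """L2 norm of the assumed PSF rho * H g / ||g|| for H = [[a, b], [b, c]]."""
    g_norm = math.hypot(g[0], g[1])
    hg = (a * g[0] + b * g[1], b * g[0] + c * g[1])
    return rho * math.hypot(hg[0], hg[1]) / g_norm


def psf_bound(a, b, c, rho):
    """Spectral-norm bound rho * max_i |delta_i| over the Hessian eigenvalues."""
    return rho * max(abs(d) for d in sym2_eigenvalues(a, b, c))


if __name__ == "__main__":
    random.seed(0)
    worst_ratio = 0.0
    for _ in range(10_000):
        a, b, c = (random.uniform(-5.0, 5.0) for _ in range(3))
        g = (random.uniform(-1.0, 1.0), random.uniform(-1.0, 1.0))
        ratio = psf_norm(a, b, c, g, 0.05) / psf_bound(a, b, c, 0.05)
        worst_ratio = max(worst_ratio, ratio)
    # Should not exceed 1, up to floating-point rounding.
    print("largest norm/bound ratio observed:", worst_ratio)
```

The bound is tight when $\mathbf{g}$ is aligned with the eigenvector of the eigenvalue with the largest magnitude.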

Figures (2)

  • Figure 1: ARSAM speeds up SAM by adaptively reusing the decomposed PSF over the next several iterations.
  • Figure 2: The trends of $||\nabla L_i^{SGD}||$ and $||\nabla L_i^{PSF}||$ are similar across different networks (i.e., ResNet-18, WideResNet-28-10, PyramidNet-110) during training (best viewed in color).
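The idea in Figure 1, reusing a decomposed PSF over several iterations, can be sketched as a training loop. In this sketch, which is our own illustration and not the authors' implementation, the PSF is estimated as the difference between a two-pass SAM gradient and the SGD gradient, stored, and then added ("mixed") to the fresh SGD gradient on the following steps, which cost only one gradient pass each. The fixed `reuse_interval` is a hypothetical stand-in for ARSAM's adaptive schedule, and the toy loss, learning rate, and all names are assumptions.

```python
import math

# Toy quadratic loss L(w) = 0.5 * w^T A w (illustrative assumption, not from the paper).
A = [[3.0, 1.0],
     [1.0, 2.0]]


def grad(w):
    """Gradient A w of the toy loss; each call counts as one gradient pass."""
    return [sum(a * x for a, x in zip(row, w)) for row in A]


def train_with_psf_reuse(w, lr=0.1, rho=0.05, steps=20, reuse_interval=3):
    """SAM-style loop that refreshes the PSF only every `reuse_interval` steps.

    `reuse_interval` is a hypothetical fixed schedule; ARSAM decides when to
    refresh the PSF adaptively, which this sketch does not model.
    Returns the final weights and the number of gradient evaluations.
    """
    psf = [0.0] * len(w)
    n_grad_evals = 0
    for t in range(steps):
        g = grad(w)                                   # one pass: SGD gradient
        n_grad_evals += 1
        if t % reuse_interval == 0:
            g_norm = math.sqrt(sum(x * x for x in g)) or 1.0
            w_adv = [wi + rho * gi / g_norm for wi, gi in zip(w, g)]
            g_sam = grad(w_adv)                       # second pass: SAM gradient
            n_grad_evals += 1
            psf = [s - gi for s, gi in zip(g_sam, g)]  # PSF estimate = SAM - SGD
        # Mix the fresh SGD gradient with the (possibly stale) stored PSF.
        update = [gi + pi for gi, pi in zip(g, psf)]
        w = [wi - lr * ui for wi, ui in zip(w, update)]
    return w, n_grad_evals


if __name__ == "__main__":
    w_final, n_evals = train_with_psf_reuse([1.0, 1.0])
    print("final weights:", w_final, "gradient evaluations:", n_evals)
```

With `reuse_interval=1` every step refreshes the PSF and the update reduces to plain SAM, at two gradient evaluations per step; larger intervals trade PSF staleness for fewer evaluations. For example, 20 steps with `reuse_interval=3` refresh the PSF at steps 0, 3, 6, 9, 12, 15, and 18, i.e. 7 times, for 27 gradient evaluations instead of SAM's 40.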

Theorems & Definitions (3)

  • Theorem 1
  • Theorem 2
  • Theorem 3