
Asynchronous Sharpness-Aware Minimization For Fast and Accurate Deep Learning

Junhyuk Jo, Jihyun Lim, Sunwoo Lee

TL;DR

This work tackles the computational bottleneck of Sharpness-Aware Minimization (SAM) by introducing asynchronous SAM, which uses a small, fixed gradient staleness and a system-aware perturbation batch size to hide the perturbation time behind the model update. The approach runs gradient ascent (model perturbation) and gradient descent (model update) concurrently, exploiting heterogeneous CPU/GPU resources while preserving the gradient-norm penalty that drives SAM's generalization. A convergence analysis under standard assumptions shows convergence to a bounded neighborhood, and extensive experiments on CIFAR-10/100, Flowers102, Speech Commands, Tiny-ImageNet, and Vision Transformer fine-tuning show accuracy comparable to SAM with training time close to SGD. The method offers practical benefits for real-world systems with diverse resources and could be extended with gradient recycling or scaled to larger settings.
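
The stale-gradient idea summarized above can be illustrated with a minimal sketch on a toy two-dimensional quadratic loss. This is not the authors' Algorithm 1: the quadratic, the step size `lr`, the radius `rho`, the staleness of one step, and all function names are illustrative assumptions, and the ascent and descent gradients are computed sequentially here rather than on separate devices. The only difference between the two variants is which gradient defines the perturbation: plain SAM normalizes the gradient at the current weights, while the asynchronous variant normalizes the gradient from the previous iterate, which is what removes the data dependency between perturbation and update.

```python
import math

CURV = (1.0, 10.0)  # hypothetical curvatures of a toy quadratic loss


def loss(w):
    """Toy loss l(w) = 0.5 * sum(c_i * w_i^2)."""
    return 0.5 * sum(c * x * x for c, x in zip(CURV, w))


def grad(w):
    """Exact gradient of the toy loss."""
    return [c * x for c, x in zip(CURV, w)]


def perturbation(g, rho, eps=1e-12):
    """SAM ascent step: rho * g / ||g|| (eps guards against a zero gradient)."""
    norm = math.sqrt(sum(x * x for x in g))
    return [rho * x / (norm + eps) for x in g]


def descend(w, ascent_grad, lr, rho):
    """Perturb w along ascent_grad, then take a descent step with the gradient there."""
    e = perturbation(ascent_grad, rho)
    g = grad([x + d for x, d in zip(w, e)])
    return [x - lr * gi for x, gi in zip(w, g)]


def run(steps, lr=0.05, rho=0.05, asynchronous=False):
    w = [1.0, 1.0]
    stale = grad(w)  # the first step has no older gradient, so reuse the current one
    for _ in range(steps):
        if asynchronous:
            # Staleness tau = 1: perturb w_t with the gradient from w_{t-1}.
            # In a real system, grad(w) for the next step would be computed by a
            # separate worker while this descent runs; here it is sequential.
            fresh = grad(w)
            w = descend(w, stale, lr, rho)
            stale = fresh
        else:
            # Plain SAM: the perturbation depends on the gradient at the current w.
            w = descend(w, grad(w), lr, rho)
    return loss(w)


if __name__ == "__main__":
    print("SAM      :", run(200))
    print("async SAM:", run(200, asynchronous=True))
```

On this toy problem both variants settle into a small neighborhood of the minimum, because the normalized perturbation keeps a constant radius `rho` even near the optimum; the sketch only shows where the stale gradient enters the update, not the speedup, which depends on running the ascent on separate hardware.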

Abstract

Sharpness-Aware Minimization (SAM) is an optimization method that improves the generalization performance of machine learning models. Despite its superior generalization, SAM has not been widely used in real-world applications because of its high computational cost. In this work, we propose a novel asynchronous-parallel SAM that achieves nearly the same gradient-norm penalizing effect as the original SAM while breaking the data dependency between the model perturbation and the model update. The proposed asynchronous SAM can even entirely hide the model perturbation time by adjusting the batch size for the model perturbation in a system-aware manner. Thus, the proposed method makes it possible to fully utilize heterogeneous system resources such as CPUs and GPUs. Our extensive experiments demonstrate the practical benefits of the proposed asynchronous approach. For example, asynchronous SAM achieves Vision Transformer fine-tuning accuracy on CIFAR-100 comparable to that of the original SAM while having almost the same training time as SGD.
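
The system-aware choice of perturbation batch size described in the abstract can be reasoned about with simple arithmetic. The sketch below assumes, hypothetically, that compute time scales linearly with the number of samples, so a device processing `throughput` samples per second needs `batch / throughput` seconds; it then picks the largest perturbation batch whose ascent time fits within one descent step. The function name `perturbation_batch_size`, the linear cost model, and the throughput figures in the example are illustrative assumptions, not values or a procedure taken from the paper.

```python
import math


def perturbation_batch_size(descent_batch, descent_throughput, ascent_throughput, min_batch=1):
    """Pick the largest perturbation (ascent) batch that fits inside one descent step.

    Hypothetical linear cost model: processing n samples on a device with a given
    throughput (samples/sec) takes n / throughput seconds.
    Returns (batch_size, hidden), where hidden says whether the ascent time fits
    entirely within the descent time.
    """
    if descent_throughput <= 0 or ascent_throughput <= 0:
        raise ValueError("throughputs must be positive")
    # One descent step lasts descent_batch / descent_throughput seconds; the
    # ascent device processes ascent_throughput samples/sec during that window.
    b = math.floor(descent_batch * ascent_throughput / descent_throughput)
    if b >= min_batch:
        return b, True
    # Even the smallest allowed batch would lengthen the step: not fully hidden.
    return min_batch, False


if __name__ == "__main__":
    # Illustrative numbers only: descent device 2000 samples/sec, ascent device 250 samples/sec.
    print(perturbation_batch_size(128, 2000, 250))  # (16, True)
```

With these made-up numbers, a descent batch of 128 takes 0.064 s, and a perturbation batch of 16 on the slower device also takes 0.064 s, so the ascent is just hidden. In practice the throughputs would be measured on the actual hardware; a smaller perturbation batch hides more of the ascent cost at the price of a noisier perturbation gradient.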

Paper Structure

This paper contains 17 sections, 6 theorems, 25 equations, 5 figures, 4 tables, 1 algorithm.

Key Result

Theorem 3.1

Assume that the loss function is $\beta$-smooth and non-convex, and that the gradient variance and gradient norm are bounded. Then, if $\eta \leq \frac{1}{\beta}$, Algorithm 1 (asynchronous SAM) satisfies:

Figures (5)

  • Figure 1: Cosine similarity between the latest gradient and the previous gradient, measured over 1000 consecutive training iterations. All curves consistently show high similarity ($> 0.8$). This observation suggests that the parameter space with respect to given data tends not to change dramatically during training.
  • Figure 2: (a) Schematic illustration of SAM. The model is updated twice using the SAM update rule. The two consecutive gradients, $\nabla l(w_t)$ and $\nabla l(w_{t+1})$, are observed to be quite similar to each other (see Figure 1). (b) Schematic illustration of the proposed asynchronous SAM. Instead of using the latest gradient for the model perturbation, the stale gradients $\nabla l(w_{t-1})$ and $\nabla l(w_t)$ are used to perturb $w_t$ and $w_{t+1}$, respectively. When the staleness $\tau$ is sufficiently small, the stale gradients are similar to the latest gradients, and thus the model is expected to move toward the same minimum.
  • Figure 3: CIFAR-10 training throughput comparison (images/sec).
  • Figure 4: CIFAR-10 learning curve comparison. The proposed asynchronous SAM achieves nearly the same accuracy as generalized SAM zhao2022penalizing while taking a similar amount of time as SGD. Generalized SAM achieves the best accuracy, but it takes much more time than the other methods. Interestingly, asynchronous SGD also achieves a slightly lower training loss than SGD.
  • Figure 5: CIFAR-10 loss landscape comparison. The z-axis maximum is fixed at $10$. Asynchronous SAM clearly leads the model to a flat region, as the original SAM does. The flatter the loss landscape, the better the generalization performance is expected to be.
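
The quantity plotted in Figure 1, and the reason a small staleness $\tau$ is acceptable in Figure 2, can be illustrated by measuring the cosine similarity between gradients $\tau$ steps apart. The sketch below uses full-batch gradient descent on a toy quadratic rather than mini-batch training of a real network, so its numbers are not comparable to Figure 1; the curvatures, learning rate, and helper names (`gd_gradients`, `lagged_similarities`) are assumptions chosen for illustration.

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors given as lists."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)


def gd_gradients(steps, curv=(1.0, 10.0), lr=0.05, w0=(1.0, 1.0)):
    """Full-batch gradient descent on 0.5 * sum(c * w^2); returns the gradient at each step."""
    w = list(w0)
    grads = []
    for _ in range(steps):
        g = [c * x for c, x in zip(curv, w)]
        grads.append(g)
        w = [x - lr * gi for x, gi in zip(w, g)]
    return grads


def lagged_similarities(grads, tau):
    """Cosine similarity between each gradient and the one computed tau steps earlier."""
    return [cosine(grads[t], grads[t - tau]) for t in range(tau, len(grads))]


if __name__ == "__main__":
    grads = gd_gradients(50)
    for tau in (1, 5):
        print("tau =", tau, "min similarity =", round(min(lagged_similarities(grads, tau)), 3))
```

On this toy problem the minimum similarity for $\tau = 1$ stays above 0.9, while it drops noticeably for $\tau = 5$, mirroring the caption's argument that stale gradients are a good proxy for the latest gradient only when $\tau$ is small.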

Theorems & Definitions (11)

  • Theorem 3.1
  • Lemma A.1
  • proof
  • Lemma A.2
  • proof
  • Lemma A.3
  • proof
  • Lemma A.4
  • proof
  • Theorem A.5
  • ...and 1 more