Asynchronous Sharpness-Aware Minimization For Fast and Accurate Deep Learning
Junhyuk Jo, Jihyun Lim, Sunwoo Lee
TL;DR
This work tackles the computational bottleneck of Sharpness-Aware Minimization (SAM) by introducing asynchronous SAM, which uses a small, fixed gradient staleness and a system-aware perturbation batch size to hide the perturbation time behind the model update. The approach lets gradient ascent and gradient descent run concurrently on heterogeneous CPU/GPU resources while preserving the gradient-norm penalty that drives SAM's generalization. A convergence analysis under standard assumptions shows convergence to a bounded neighborhood, and extensive experiments on CIFAR-10/100, Flowers102, Speech Commands, Tiny-ImageNet, and Vision Transformer fine-tuning demonstrate accuracy comparable to SAM with training time close to that of SGD. The method offers practical benefits for real-world, resource-diverse systems and could be extended with gradient recycling or scaled to larger systems.
Abstract
Sharpness-Aware Minimization (SAM) is an optimization method that improves the generalization performance of machine learning models. Despite its superior generalization, SAM has not been widely adopted in real-world applications because of its high computational cost. In this work, we propose a novel asynchronous-parallel SAM that achieves nearly the same gradient-norm penalizing effect as the original SAM while breaking the data dependency between the model perturbation and the model update. The proposed asynchronous SAM can even hide the model perturbation time entirely by adjusting the batch size for the model perturbation in a system-aware manner. The proposed method thus makes it possible to fully utilize heterogeneous system resources such as CPUs and GPUs. Our extensive experiments demonstrate the practical benefits of the proposed asynchronous approach. For example, the asynchronous SAM achieves Vision Transformer fine-tuning accuracy on CIFAR-100 comparable to that of the original SAM while having almost the same training time as SGD.
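To make the data dependency concrete, the following is a minimal sketch on a toy quadratic loss, not the authors' implementation. It assumes a staleness of one step, a full-batch gradient, and a sequential loop that only marks where concurrency would occur; the function names (`grad`, `perturb`, `sam`, `async_sam`) and the hyperparameter values are hypothetical. In standard SAM, the ascent gradient must be computed at the current weights before the descent gradient can start. In the asynchronous variant sketched here, the perturbation reuses an ascent gradient computed at the previous weights, so in a real system that computation could overlap with the model update on a different device.

```python
import math

def grad(w):
    # Gradient of the toy loss f(w) = 0.5 * (10 * w0^2 + w1^2).
    return [10.0 * w[0], 1.0 * w[1]]

def perturb(w, g, rho):
    # SAM ascent step: move a distance rho along the normalized gradient.
    norm = math.sqrt(sum(x * x for x in g)) + 1e-12
    return [wi + rho * gi / norm for wi, gi in zip(w, g)]

def sam(w, steps, lr=0.05, rho=0.05):
    # Standard SAM: the descent gradient depends on the ascent gradient
    # computed at the same weights, so the two passes are serialized.
    for _ in range(steps):
        g_ascent = grad(w)
        w_adv = perturb(w, g_ascent, rho)
        g_descent = grad(w_adv)
        w = [wi - lr * gi for wi, gi in zip(w, g_descent)]
    return w

def async_sam(w, steps, lr=0.05, rho=0.05):
    # Assumed variant: the perturbation uses an ascent gradient that is one
    # step stale, which removes the dependency inside each iteration.
    stale_g = grad(w)  # warm-up: no staleness on the very first step
    for _ in range(steps):
        w_adv = perturb(w, stale_g, rho)
        g_descent = grad(w_adv)
        # In a real system this ascent gradient could be computed
        # concurrently (for example on a CPU) while the GPU runs the
        # descent step; here it is simply computed in sequence.
        next_stale = grad(w)
        w = [wi - lr * gi for wi, gi in zip(w, g_descent)]
        stale_g = next_stale
    return w

if __name__ == "__main__":
    print("SAM:      ", sam([1.0, 1.0], 200))
    print("async SAM:", async_sam([1.0, 1.0], 200))
```

On this toy problem both variants settle into a small neighborhood of the minimizer whose size scales with `rho`. The sketch omits the system-aware choice of perturbation batch size, which depends on measured device throughput.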
