Unveiling m-Sharpness Through the Structure of Stochastic Gradient Noise
Haocheng Luo, Mehrtash Harandi, Dinh Phung, Trung Le
TL;DR
The paper studies why m-sharpness improves generalization in m-SAM by developing a two-parameter stochastic framework that jointly analyzes the learning rate and the perturbation radius via a continuous-time SDE. It derives explicit SDEs for USAM and SAM variants, showing that stochastic gradient noise (SGN) induces a variance-based sharpness regularization whose strength grows as the micro-batch size decreases. Motivated by these insights, the authors introduce Reweighted SAM (RW-SAM), which weights samples by their stochastic gradient norms using a Gibbs-distribution-based scheme together with a lightweight finite-difference estimator, emulating the benefits of m-SAM with reduced overhead while remaining parallelizable. Extensive experiments on CIFAR, ImageNet, ViT fine-tuning, and GLUE validate the theory and show that RW-SAM consistently improves generalization over SAM while remaining practically efficient. The work provides both a theoretical lens on SAM dynamics and a scalable method for harnessing SGN-driven sharpness regularization in large-scale training.
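The reweighting idea can be illustrated with a small NumPy sketch on a toy least-squares problem. Everything concrete here is an assumption, not the authors' implementation: the toy model, the hypothetical helper names (`m_sam_step`, `gibbs_weights`, `fd_grad_norm_proxy`, `rw_sam_step`), the choice to apply the Gibbs weights only to the ascent (perturbation) direction, and the finite-difference proxy, which scores each sample by the directional derivative of its loss along the normalized batch gradient rather than by its full gradient norm. An m-SAM step is included for comparison.

```python
import numpy as np

# Toy model (assumption for illustration): linear least squares with per-sample loss
# l_i(theta) = 0.5 * (x_i . theta - y_i)^2, so grad l_i(theta) = (x_i . theta - y_i) * x_i.

def per_sample_losses(theta, X, y):
    return 0.5 * (X @ theta - y) ** 2

def per_sample_grads(theta, X, y):
    residual = X @ theta - y
    return residual[:, None] * X  # shape (n, d), one gradient per row

def m_sam_step(theta, X, y, lr, rho, m):
    """m-SAM step: each micro-batch of size m gets its own normalized ascent
    perturbation; gradients taken at the perturbed points are averaged.
    Assumes len(X) is divisible by m."""
    micro_grads = []
    for start in range(0, len(X), m):
        Xb, yb = X[start:start + m], y[start:start + m]
        g = per_sample_grads(theta, Xb, yb).mean(axis=0)
        eps = rho * g / (np.linalg.norm(g) + 1e-12)
        micro_grads.append(per_sample_grads(theta + eps, Xb, yb).mean(axis=0))
    return theta - lr * np.mean(micro_grads, axis=0)

def gibbs_weights(scores, temperature):
    """Gibbs (softmax) distribution over samples: w_i proportional to exp(s_i / T)."""
    z = scores / temperature
    w = np.exp(z - z.max())  # subtract the max for numerical stability
    return w / w.sum()

def fd_grad_norm_proxy(theta, X, y, h=1e-3):
    """Hypothetical finite-difference score: |l_i(theta + h d) - l_i(theta)| / h,
    which approximates |grad l_i(theta) . d| for the normalized batch gradient d.
    It needs one extra forward pass instead of per-sample backward passes."""
    g = per_sample_grads(theta, X, y).mean(axis=0)
    d = g / (np.linalg.norm(g) + 1e-12)
    diff = per_sample_losses(theta + h * d, X, y) - per_sample_losses(theta, X, y)
    return np.abs(diff) / h

def rw_sam_step(theta, X, y, lr, rho, temperature=1.0):
    """RW-SAM-style step (sketch): Gibbs-reweighted ascent direction, then a
    descent step with the uniform mini-batch gradient at the perturbed point."""
    w = gibbs_weights(fd_grad_norm_proxy(theta, X, y), temperature)
    g_w = w @ per_sample_grads(theta, X, y)  # weighted ascent direction, shape (d,)
    eps = rho * g_w / (np.linalg.norm(g_w) + 1e-12)
    g = per_sample_grads(theta + eps, X, y).mean(axis=0)
    return theta - lr * g

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.standard_normal((32, 5))
    y = X @ rng.standard_normal(5) + 0.1 * rng.standard_normal(32)
    theta = np.zeros(5)
    for _ in range(200):
        theta = rw_sam_step(theta, X, y, lr=0.05, rho=0.05)
    print("final loss:", per_sample_losses(theta, X, y).mean())
```

With a very large temperature the Gibbs weights become uniform and `rw_sam_step` reduces to a standard SAM step computed on the whole mini-batch; lower temperatures concentrate the perturbation on high-scoring samples. Unlike `m_sam_step`, the sketch never splits the batch into micro-batches.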
Abstract
Sharpness-aware minimization (SAM) has emerged as a highly effective technique for improving model generalization, but its underlying principles are not fully understood. We investigate the phenomenon known as m-sharpness, where the performance of SAM improves monotonically as the micro-batch size used to compute perturbations decreases. In practice, the empirical m-sharpness effect underpins the deployment of SAM in distributed training, yet a rigorous theoretical account has been lacking. To explain m-sharpness, we leverage an extended Stochastic Differential Equation (SDE) framework and analyze the structure of stochastic gradient noise (SGN) to characterize the dynamics of various SAM variants, including n-SAM and m-SAM. Our findings reveal that the stochastic noise introduced during SAM perturbations inherently induces a variance-based sharpness regularization effect. Motivated by these theoretical insights, we introduce Reweighted SAM (RW-SAM), which employs sharpness-weighted sampling to mimic the generalization benefits of m-SAM while remaining parallelizable. Comprehensive experiments validate our theoretical analysis and the effectiveness of the proposed method.
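Why perturbation noise acts as a variance penalty that depends on micro-batch size can be seen from a standard first-order Taylor argument, offered only as a heuristic sketch and not as the paper's SDE derivation. Assume the unnormalized (USAM) update, in which a micro-batch B of m samples drawn i.i.d. with replacement is used both to form the perturbation ρ∇L_B(θ) and to evaluate the gradient; write ℓ_i for per-sample losses, L = E_i[ℓ_i] for the population loss, and Σ(θ) for the covariance of per-sample gradients (all notation assumed here).

```latex
\begin{align*}
\nabla L_B\big(\theta + \rho\,\nabla L_B(\theta)\big)
  &\approx \nabla L_B(\theta) + \rho\,\nabla^2 L_B(\theta)\,\nabla L_B(\theta)
   = \nabla L_B(\theta) + \tfrac{\rho}{2}\,\nabla \big\|\nabla L_B(\theta)\big\|^2, \\
\mathbb{E}_B\big[\|\nabla L_B(\theta)\|^2\big]
  &= \|\nabla L(\theta)\|^2 + \tfrac{1}{m}\operatorname{tr}\Sigma(\theta),
  \qquad \Sigma(\theta) = \operatorname{Cov}_i\big[\nabla \ell_i(\theta)\big], \\
\mathbb{E}_B\Big[\nabla L_B\big(\theta + \rho\,\nabla L_B(\theta)\big)\Big]
  &\approx \nabla\Big( L(\theta) + \tfrac{\rho}{2}\,\|\nabla L(\theta)\|^2
   + \tfrac{\rho}{2m}\operatorname{tr}\Sigma(\theta) \Big).
\end{align*}
```

The last term, (ρ/2m) tr Σ(θ), is a gradient-variance penalty whose weight grows as m shrinks, which matches the direction of the m-sharpness effect. This expansion covers only the unnormalized variant to first order in ρ.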
