Unveiling m-Sharpness Through the Structure of Stochastic Gradient Noise

Haocheng Luo, Mehrtash Harandi, Dinh Phung, Trung Le

TL;DR

The paper studies the m-sharpness effect: why SAM generalizes better when perturbations are computed on smaller micro-batches, as in m-SAM. It develops a two-parameter stochastic framework that treats the learning rate and the perturbation radius jointly through a continuous-time SDE, and derives explicit SDEs for USAM and SAM variants. These SDEs show that stochastic gradient noise (SGN) induces a variance-based sharpness regularization whose strength depends on the micro-batch size. Motivated by this, the authors introduce Reweighted SAM (RW-SAM), which weights samples by their stochastic gradient norms using a Gibbs-distribution-based scheme and a lightweight finite-difference estimator. RW-SAM remains parallelizable and aims to emulate m-SAM’s benefits with reduced overhead. Experiments on CIFAR, ImageNet, ViT fine-tuning, and GLUE support the theory and show RW-SAM consistently improving generalization over SAM. The work offers both a theoretical lens on SAM dynamics and a scalable way to exploit SGN-driven sharpness regularization in large-scale training.
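The Gibbs-style reweighting mentioned in the summary can be illustrated with a minimal sketch. It assumes per-sample gradient norms are already available; the function names, the `temperature` parameter, and the choice to favor larger norms are illustrative assumptions, not the authors' implementation. The finite-difference norm estimator is not shown.

```python
import numpy as np

def gibbs_weights(grad_norms, temperature=1.0):
    """Gibbs (softmax) weights that favor samples with larger gradient norms.

    `temperature` is a hypothetical knob: large values flatten the weights
    toward uniform, small values concentrate them on the largest norms.
    """
    scores = np.asarray(grad_norms, dtype=float) / temperature
    scores = scores - scores.max()  # shift for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum()

def reweighted_gradient(per_sample_grads, weights):
    """Weighted sum of per-sample gradients: shape (n, d) -> (d,)."""
    return np.asarray(weights) @ np.asarray(per_sample_grads, dtype=float)

# Toy usage: three samples with 2-dimensional gradients.
grads = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0]])
norms = np.linalg.norm(grads, axis=1)
w = gibbs_weights(norms, temperature=1.0)
g = reweighted_gradient(grads, w)
```

Because the weights depend only on per-sample quantities, they can be computed in one batched pass, which is the sense in which such a scheme stays parallelizable.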

Abstract

Sharpness-aware minimization (SAM) has emerged as a highly effective technique for improving model generalization, but its underlying principles are not fully understood. We investigate the phenomenon known as m-sharpness, where the performance of SAM improves monotonically as the micro-batch size for computing perturbations decreases. In practice, the empirical m-sharpness effect underpins the deployment of SAM in distributed training, yet a rigorous theoretical account has remained lacking. To provide a theoretical explanation for m-sharpness, we leverage an extended Stochastic Differential Equation (SDE) framework and analyze the structure of stochastic gradient noise (SGN) to characterize the dynamics of various SAM variants, including n-SAM and m-SAM. Our findings reveal that the stochastic noise introduced during SAM perturbations inherently induces a variance-based sharpness regularization effect. Motivated by our theoretical insights, we introduce Reweighted SAM (RW-SAM), which employs sharpness-weighted sampling to mimic the generalization benefits of m-SAM while remaining parallelizable. Comprehensive experiments validate the effectiveness of our theoretical analysis and proposed method.
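To make the role of the micro-batch size concrete, here is a minimal sketch of an m-SAM gradient on a toy least-squares problem. It assumes the usual formulation in which each micro-batch of size m computes its own normalized perturbation and the resulting gradients are averaged. The loss, the function names, and the default `rho` are illustrative assumptions, not taken from the paper.

```python
import numpy as np

def grad_mse(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    residual = X @ w - y
    return X.T @ residual / len(y)

def m_sam_gradient(w, X, y, m, rho=0.05, eps=1e-12):
    """Average of SAM gradients computed independently on micro-batches of size m.

    Each micro-batch perturbs w along its own normalized gradient (radius rho)
    and evaluates the gradient at the perturbed point on the same micro-batch.
    With m == len(y) this reduces to ordinary mini-batch SAM.
    """
    n = len(y)
    if n % m != 0:
        raise ValueError("batch size must be divisible by m")
    grads = []
    for start in range(0, n, m):
        Xb, yb = X[start:start + m], y[start:start + m]
        g = grad_mse(w, Xb, yb)
        perturbation = rho * g / (np.linalg.norm(g) + eps)
        grads.append(grad_mse(w + perturbation, Xb, yb))
    return np.mean(grads, axis=0)

# Toy usage: one mini-batch of 8 samples split into micro-batches of size 2.
rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
y = rng.normal(size=8)
w = rng.normal(size=3)
g_msam = m_sam_gradient(w, X, y, m=2)
```

Smaller m means each perturbation is driven by a noisier gradient estimate, which is the mechanism the SDE analysis ties to variance-based regularization. The Python loop over micro-batches also shows why naive m-SAM is hard to parallelize within a device.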

Paper Structure

This paper contains 38 sections, 24 theorems, 157 equations, 3 figures, 10 tables, 1 algorithm.

Key Result

Theorem 3.3

Under the Gaussian assumption and mild regularity conditions, the solution of the mini-batch USAM SDE is an order-$(1,1)$ weak approximation of the discrete mini-batch USAM update with batch size $|\gamma|$.
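For orientation, the discrete rule that the SDE approximates can be written in the standard unnormalized-SAM (USAM) form. The SDE itself is not shown here, and the symbols $x_k$, $\eta$, $\rho$, and $\gamma_k$ are assumed notation rather than copied from the paper:

$$
x_{k+1} = x_k - \eta \, \nabla f_{\gamma_k}\!\bigl(x_k + \rho \, \nabla f_{\gamma_k}(x_k)\bigr),
$$

where $f_{\gamma_k}$ is the loss averaged over the mini-batch $\gamma_k$, $\eta$ is the learning rate, and $\rho$ is the perturbation radius. SAM differs only in normalizing the inner gradient, that is, in perturbing by $\rho \, \nabla f_{\gamma_k}(x_k) / \lVert \nabla f_{\gamma_k}(x_k) \rVert$. The order-$(1,1)$ label presumably refers to the approximation order in the two parameters $\eta$ and $\rho$.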

Figures (3)

  • Figure 1: Speed of escaping poor minima, measured by test accuracy.
  • Figure 2: Variance of SGN over iterations.
  • Figure 4: Performance comparison on CIFAR-10 across different noise ratios.

Theorems & Definitions (45)

  • Definition 3.2: Two-parameter weak approximation
  • Theorem 3.3: Mini-batch USAM SDE (informal statement of the full theorem in the appendix; adapted from Theorem 3.2 of Compagnoni et al., 2023)
  • Theorem 3.4: n-USAM SDE (informal statement of the full theorem in the appendix)
  • Theorem 3.5: m-USAM SDE (informal statement of the full theorem in the appendix)
  • Theorem 3.6: Mini-batch SAM SDE (informal statement of the full theorem in the appendix; adapted from Theorem 3.5 of Compagnoni et al., 2023)
  • Theorem 3.7: n-SAM SDE (informal statement of the full theorem in the appendix)
  • Theorem 3.8: m-SAM SDE (informal statement of the full theorem in the appendix)
  • Proposition 3.9
  • Definition A.1
  • Theorem A.2
  • ...and 35 more