
A Unified Stability Analysis of SAM vs SGD: Role of Data Coherence and Emergence of Simplicity Bias

Wei-Kai Chang, Rajiv Khanna

TL;DR

This work develops a unified linear stability framework that connects data geometry, via a data-dependent coherence matrix, to the optimization dynamics of SGD, random perturbations, and Sharpness-Aware Minimization (SAM). By analyzing two-layer ReLU networks under exact interpolation, it shows that highly coherent feature directions across training examples yield greater stability and underpin a Simplicity Bias, where simpler, shared-feature solutions are favored. SAM is shown to tighten stability further by modifying the effective Hessian to $H_{\text{SAM}} = H\left(I + \frac{\rho}{\alpha}H\right)$, which not only flattens minima but also biases toward highly coherent, low-complexity solutions. The theory is complemented by empirical validation on a two-layer ReLU model and CIFAR-scale experiments, illustrating how coherence evolves during training and how SAM reduces coherence and feature rank. Overall, the paper offers a coherence-driven lens for understanding generalization and suggests new optimizer designs that leverage data geometry to induce desirable representations.
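
To make the effective-Hessian claim concrete, the following sketch checks it on a quadratic loss $L(w) = \frac{1}{2} w^\top H w$, where the formula holds exactly rather than only to first order. It assumes an unnormalized SAM step that first ascends by $(\rho/\alpha)\nabla L$ and then takes the gradient at the perturbed point; the precise roles of $\rho$ and $\alpha$ are defined in the paper, and all function names here are hypothetical. Stability is judged with the standard full-batch gradient-descent criterion $\eta\,\lambda_{\max} \le 2$, a simpler proxy for the paper's SGD and SAM conditions.

```python
import numpy as np


def sam_effective_hessian(H, rho, alpha):
    """H_SAM = H (I + (rho/alpha) H), the form stated in the TL;DR."""
    d = H.shape[0]
    return H @ (np.eye(d) + (rho / alpha) * H)


def sam_gradient_quadratic(H, w, rho, alpha):
    """Hypothetical unnormalized SAM gradient on L(w) = 0.5 * w^T H w.

    Assumption: ascend by (rho/alpha) * grad, then evaluate the gradient there.
    """
    grad = H @ w                       # gradient of the quadratic at w
    w_adv = w + (rho / alpha) * grad   # perturbed (adversarial) point
    return H @ w_adv                   # gradient at the perturbed point


def is_linearly_stable(H_eff, lr):
    """Gradient descent on a PSD quadratic is non-divergent iff lr * lambda_max <= 2."""
    lam_max = np.linalg.eigvalsh(H_eff).max()
    return lr * lam_max <= 2.0


if __name__ == "__main__":
    H = np.diag([1.0, 4.0])
    H_sam = sam_effective_hessian(H, rho=0.05, alpha=0.1)
    print("eigenvalues of H_SAM:", np.linalg.eigvalsh(H_sam))
    print("stable (GD, SAM):", is_linearly_stable(H, lr=0.4), is_linearly_stable(H_sam, lr=0.4))
```

With $H = \mathrm{diag}(1, 4)$ and $\rho/\alpha = 0.5$, each eigenvalue maps as $\lambda \mapsto \lambda\left(1 + \frac{\rho}{\alpha}\lambda\right)$, so $4 \mapsto 12$; at step size 0.4 the minimum is stable for plain gradient descent but not under the SAM update. Because the map inflates large eigenvalues quadratically, the sharpest directions are penalized most.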

Abstract

Understanding the dynamics of optimization in deep learning is increasingly important as models scale. While stochastic gradient descent (SGD) and its variants reliably find solutions that generalize well, the mechanisms driving this generalization remain unclear. Notably, these algorithms often prefer flatter or simpler minima, particularly in overparameterized settings. Prior work has linked flatness to generalization, and methods like Sharpness-Aware Minimization (SAM) explicitly encourage flatness, but a unified theory connecting data structure, optimization dynamics, and the nature of learned solutions is still lacking. In this work, we develop a linear stability framework that analyzes the behavior of SGD, random perturbations, and SAM, particularly in two-layer ReLU networks. Central to our analysis is a coherence measure that quantifies how gradient curvature aligns across data points, revealing why certain minima are stable and favored during training.
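
The coherence measure can be sketched numerically. Because Definition 1 is not reproduced on this page, the sketch assumes the following form (an assumption, not taken from the paper): per-example Hessians $H_i$ that are positive semidefinite, a coherence matrix with entries $S_{ij} = \lVert H_i^{1/2} H_j^{1/2} \rVert_F$, and a ratio $\sigma = \lambda_{\max}(S) / \max_i \lambda_{\max}(H_i)$. The paper's exact definition and normalization may differ, and the function names are hypothetical.

```python
import numpy as np


def psd_sqrt(A):
    """Symmetric square root of a PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(A)
    vals = np.clip(vals, 0.0, None)          # clip tiny negative round-off
    return (vecs * np.sqrt(vals)) @ vecs.T   # V diag(sqrt(vals)) V^T


def coherence_matrix(hessians):
    """Assumed definition: S[i, j] = || H_i^{1/2} H_j^{1/2} ||_F."""
    roots = [psd_sqrt(H) for H in hessians]
    n = len(roots)
    S = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            S[i, j] = np.linalg.norm(roots[i] @ roots[j], ord="fro")
    return S


def coherence_ratio(hessians):
    """Assumed normalization: sigma = lambda_max(S) / max_i lambda_max(H_i)."""
    S = coherence_matrix(hessians)            # symmetric, so eigvalsh applies
    lam_S = np.linalg.eigvalsh(S).max()
    lam_H = max(np.linalg.eigvalsh(H).max() for H in hessians)
    return lam_S / lam_H


if __name__ == "__main__":
    v = np.array([1.0, 0.0, 0.0])
    shared = [np.outer(v, v) for _ in range(3)]      # one shared curvature direction
    orthogonal = [np.outer(e, e) for e in np.eye(3)]  # disjoint curvature directions
    print(coherence_ratio(shared), coherence_ratio(orthogonal))
```

When every example shares one curvature direction ($H_i = vv^\top$ with $\lVert v \rVert = 1$), every entry of $S$ equals 1 and $\sigma = n$; when the $H_i$ are mutually orthogonal rank-one projectors, $S = I$ and $\sigma = 1$. Under this assumed definition, a larger $\sigma$ signals curvature that is shared across training examples.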

Paper Structure

This paper contains 35 sections, 15 theorems, 106 equations, 8 figures, 2 tables.

Key Result

Theorem 3.1

Given update rule eqn:RandomPerturbUpdateRule,

Figures (8)

  • Figure 2: 2-layer ReLU network. SAM imposes strong regularization on the maximum elementwise Hessian eigenvalue, and this also reduces the largest eigenvalue of the coherence matrix, which implies the stability condition is satisfied with smaller $\sigma$.
  • Figure 3: 2-layer ReLU network. SAM imposes strong regularization on the elementwise maximum Hessian eigenvalue, and this also reduces the largest eigenvalue of the coherence matrix, meaning the stability condition can be satisfied with a smaller $\sigma$. Our experiments indicate that the sharpness of the solution strongly constrains the eigenvalues of the coherence matrix.
  • Figure 4: 2-layer ReLU network. (Left) Comparison of SGD and SAM with different values of $\rho$. (Middle) The same set of experiments with the learning rate increased from 0.1 to 0.3 (black to red). (Right) SGD with different contrast-loss strengths (0.0, 0.1, 0.01). Throughout the experiments, we observe a uniform shift across algorithms and strengths, but the relationship between $\max_i \lambda_{\max}(H_i)$ and $\lambda_{\max}(S)$ consistently forms a strong regression line.
  • ...and 5 more figures

Theorems & Definitions (25)

  • Definition 1
  • Theorem 3.1
  • Theorem 3.2: (Simplified) Linear Stability of SAM
  • Theorem 3.3
  • Theorem 3.4: Coherence Characterization of Memorization
  • Definition 2
  • Theorem 3.5: SGD Stability of $(C, r)$-Generalizing Solutions
  • Theorem 3.6: SAM Stability of $(C, r)$-Generalizing Solutions
  • Definition 3
  • Lemma C.1
  • ...and 15 more