μP$^2$: Effective Sharpness Aware Minimization Requires Layerwise Perturbation Scaling

Moritz Haas, Jin Xu, Volkan Cevher, Leena Chennuru Vankadara

TL;DR

This work provides a rigorous infinite-width analysis of Sharpness Aware Minimization (SAM) via Tensor Programs, revealing that standard SAM dynamics collapse to last-layer perturbations in wide networks. It introduces the Maximal Update and Perturbation Parameterization (μP$^2$), a layerwise perturbation scaling under which every layer both learns features and is effectively perturbed. This in turn enables transfer of both the learning rate $\eta$ and the perturbation radius $\rho$ across model widths and architectures. The authors demonstrate empirically, across MLPs, ResNets, and Vision Transformers, that μP$^2$ achieves joint transfer of $\eta$ and $\rho$ and improves generalization relative to SGD and to μP with global perturbations. They also extend the approach to SAM variants such as ASAM and SAM-ON. The results offer practical scaling insights for applying SAM to large models and suggest general spectral perturbation conditions from which layerwise perturbation rules for other perturbation strategies can be derived. The work lays a foundation for scalable, transferable SAM across diverse architectures and motivates future exploration of data-aware scaling and broader perturbation rules in deep learning optimization.

Abstract

Sharpness Aware Minimization (SAM) enhances performance across various neural architectures and datasets. As models are continually scaled up to improve performance, a rigorous understanding of SAM's scaling behaviour is paramount. To this end, we study the infinite-width limit of neural networks trained with SAM, using the Tensor Programs framework. Our findings reveal that the dynamics of standard SAM effectively reduce to applying SAM solely in the last layer in wide neural networks, even with optimal hyperparameters. In contrast, we identify a stable parameterization with layerwise perturbation scaling, which we call $\textit{Maximal Update and Perturbation Parameterization}$ ($μ$P$^2$), that ensures all layers are both feature learning and effectively perturbed in the limit. Through experiments with MLPs, ResNets and Vision Transformers, we empirically demonstrate that $μ$P$^2$ achieves hyperparameter transfer of the joint optimum of learning rate and perturbation radius across model scales. Moreover, we provide an intuitive condition to derive $μ$P$^2$ for other perturbation rules like Adaptive SAM and SAM-ON, also ensuring balanced perturbation effects across all layers.
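The SAM update discussed in the abstract can be sketched in a few lines of dependency-free Python. This is a minimal sketch, not the authors' implementation. Parameters are assumed to be flat per-layer lists of floats, `grad_fn` is a hypothetical callable that returns per-layer gradients, and `layer_scales` is a hypothetical per-layer multiplier on the radius. With `layer_scales=None` the sketch is standard SAM: it perturbs by $\epsilon = \rho\, g/\|g\|_2$ using the global gradient norm, then updates the unperturbed weights with the gradient evaluated at $w+\epsilon$.

```python
import math
from typing import Callable, List, Optional, Sequence

Params = List[List[float]]  # one flat list of weights per layer (assumed layout)


def sam_step(
    params: Params,
    grad_fn: Callable[[Params], Params],
    lr: float,
    rho: float,
    layer_scales: Optional[Sequence[float]] = None,
) -> Params:
    """One SAM step: ascend to w + eps, then descend using the gradient taken there."""
    grads = grad_fn(params)
    # Global L2 norm of the gradient across all layers (standard SAM normalization).
    norm = math.sqrt(sum(g * g for layer in grads for g in layer)) + 1e-12
    # Hypothetical per-layer multipliers s_l; s_l = 1 for all l recovers standard SAM.
    scales = list(layer_scales) if layer_scales is not None else [1.0] * len(params)
    if len(scales) != len(params):
        raise ValueError("layer_scales must have one entry per layer")
    # Perturbation eps_l = rho * s_l * g_l / ||g||_2.
    perturbed = [
        [w + rho * s * g / norm for w, g in zip(w_layer, g_layer)]
        for w_layer, g_layer, s in zip(params, grads, scales)
    ]
    # The gradient at the perturbed point drives the update of the original weights.
    sam_grads = grad_fn(perturbed)
    return [
        [w - lr * g for w, g in zip(w_layer, g_layer)]
        for w_layer, g_layer in zip(params, sam_grads)
    ]
```

For example, on the quadratic loss $L(w)=\tfrac12\|w\|_2^2$ (gradient $w$), starting from $w=(3,4)$ with $\eta=0.1$ and $\rho=0.5$, the ascent step moves to $(3.3, 4.4)$ and the update returns approximately $(2.67, 3.56)$. The per-layer multipliers are only a hook for experimenting with layerwise radii. The specific μP$^2$ scalings are derived in the paper and are not reproduced here.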

Paper Structure

This paper contains 57 sections, 26 theorems, 151 equations, 27 figures, 9 tables.

Key Result

Proposition 1

Under $\mu$P with the standard SAM update rule and default perturbation given in eq:bcd_sam_rule_global, the output function becomes unbounded after the first update step in the infinite-width limit for any fixed, positive learning rate $\eta>0$ and perturbation radius $\rho>0$.
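Proposition 1 concerns a fixed, width-independent radius $\rho$. Figure 1 instead uses the global scaling $\rho\cdot n^{-1/2}$, and Figure 3 describes layerwise choices through exponents $\{b_l, c_l\}$ and perturbation exponents $d$, $d_l$. The helper below turns such exponents into concrete per-layer values for a given width $n$. It is a sketch under an assumed convention (our reading of the captions, not the paper's exact definitions): layer $l$ is initialized with standard deviation $n^{-b_l}$, trained with learning rate $\eta\, n^{-c_l}$, and perturbed with radius $\rho\, n^{-(d+d_l)}$. The function name `layerwise_scalings` is hypothetical.

```python
from typing import Dict, List, Sequence


def layerwise_scalings(
    width: int,
    eta: float,
    rho: float,
    b: Sequence[float],
    c: Sequence[float],
    d: float,
    d_layers: Sequence[float],
) -> List[Dict[str, float]]:
    """Map width-scaling exponents to per-layer values (hypothetical helper).

    ASSUMED convention, for illustration only:
      init std of layer l       = width ** -b_l
      learning rate of layer l  = eta * width ** -c_l
      perturbation radius of l  = rho * width ** -(d + d_l)
    """
    return [
        {
            "init_std": width ** -b_l,
            "lr": eta * width ** -c_l,
            "radius": rho * width ** -(d + d_l),
        }
        for b_l, c_l, d_l in zip(b, c, d_layers)
    ]


# Global perturbation scaling rho * n^{-1/2} (as in Figure 1): d = 1/2, all d_l = 0.
for n in (256, 1024, 4096):
    print(n, layerwise_scalings(n, 0.1, 1.0, [0.0], [0.0], 0.5, [0.0])[0]["radius"])
```

Under global scaling every layer receives the same radius, which halves each time the width quadruples. Layerwise rules such as μP$^2$ correspond to choosing the $d_l$ differently per layer.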

Figures (27)

  • Figure 1: Left and center ($\mathbf{\boldsymbol{\mu} P^2}$ transfers both $\eta$ and $\rho$): Test accuracy as a function of learning rate $\eta$ and perturbation radius $\rho$ of a 3-layer MLP in $\mu$P trained with SAM on CIFAR10 for various widths with global perturbation scaling $\rho\cdot n^{-1/2}$ (left) and our layerwise perturbation scaling $\mu$P$^2$ (center), averaged over 3 independent runs. '$\times$' denotes the optimum. Blue contours (the darker, the wider) denote the region within $1\%$ of the optimal test accuracy, smoothed with a Gaussian filter. Grey regions (the lighter, the wider) denote the unstable regime below $30\%$ test accuracy. Right ($\mathbf{\boldsymbol{\mu} P^2}$ improves generalization): Same as left but sliced at the optimal learning rate of both parameterizations for width $4096$, with the base optimizer SGD in $\mu$P (dashed line) as a baseline. Average and $2\sigma$-CI from $16$ independent runs. Global perturbation scaling $\rho\cdot n^{-1/2}$ achieves a width-independent critical perturbation radius at which training becomes unstable, but does not consistently improve over SGD in $\mu$P and does not transfer the optimal $(\eta,\rho)$. $\mu$P$^2$ achieves joint transfer in $(\eta,\rho)$ and improves generalization performance.
  • Figure 2: (SAM with eq:bcd_sam_rule_global effectively only perturbs the last layer) Layerwise weight perturbations (top) and normalized activation updates $\|\Delta x^l\|_2$ (bottom) for SAM, last-layer SAM and SGD as a baseline across widths after training a $3$-layer MLP in $\mu$P with global perturbation scaling $\rho\cdot n^{-1/2}$ for 20 epochs on CIFAR10. Average and CI are computed from $4$ independent runs. Perturbations are normalized by the weight spectral norm to measure their effect on the layer's output. Activation updates are normalized by $\sqrt{\text{dim}(\Delta x^l)}$ to measure coordinatewise updates. We provide more neural network statistics in sec:llsam.
  • Figure 3: (Perturbation phase characterization of bcd-parameterizations) Given a choice of layerwise initialization and learning rate scalings $\{b_l,c_l\}_{l\in [L+1]}$, the maximal feature perturbation scaling $\tilde{r}$ and the last-layer perturbation scaling $d+d_{L+1}$ determine whether a $bcd$-parameterization is unstable, has effective SGD dynamics, has effective perturbations in some but not all layers, or may have effective perturbations in all layers. In SP or NTP (left), no choice of perturbation scalings achieves effective perturbations in all layers, whereas in $\mu$P (right), there is a unique choice, as provided in thm:perturbation_scaling.
  • Figure 4: ($\rho$-transfer in ViTs) Training a ViT with SAM in $\mu$P$^2$ on ImageNet1K from scratch for 100 epochs yields $\rho$-transfer and large improvements over AdamW in $\mu$P (dashed lines).
  • Figure 5: (Stable training dynamics) SAM in $\mu$P$^2$ stabilizes training dynamics for a ResNet-18 with width multiplier $2$.
  • ...and 22 more figures
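The two normalizations described in the Figure 2 caption can be computed as follows. This is a minimal stdlib-only sketch, not the authors' measurement code. We assume the layer perturbation is also measured in spectral norm, so its relative size is $\|\epsilon^l\|_2/\|W^l\|_2$ (operator norms), and the coordinatewise activation update is $\|\Delta x^l\|_2/\sqrt{\dim(\Delta x^l)}$. Matrices are row-major lists of lists, the spectral norm is estimated by power iteration, and all function names are hypothetical.

```python
import math
import random
from typing import List, Sequence

Matrix = List[List[float]]  # row-major matrix (assumed layout)


def _matvec(a: Matrix, v: Sequence[float]) -> List[float]:
    """Compute A @ v."""
    return [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in a]


def _rmatvec(a: Matrix, u: Sequence[float]) -> List[float]:
    """Compute A^T @ u."""
    return [sum(a[i][j] * u[i] for i in range(len(a))) for j in range(len(a[0]))]


def _l2(v: Sequence[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def spectral_norm(a: Matrix, iters: int = 100, seed: int = 0) -> float:
    """Largest singular value of A, estimated by power iteration on A^T A."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(len(a[0]))]
    for _ in range(iters):
        v = _rmatvec(a, _matvec(a, v))
        norm = _l2(v)
        if norm == 0.0:  # A is zero (or v fell into its null space)
            return 0.0
        v = [x / norm for x in v]
    return _l2(_matvec(a, v))


def relative_perturbation(eps: Matrix, weight: Matrix) -> float:
    """Spectral size of a layer's perturbation relative to its weight matrix."""
    return spectral_norm(eps) / spectral_norm(weight)


def coordinatewise_update(delta_x: Sequence[float]) -> float:
    """||Delta x||_2 / sqrt(dim(Delta x)): typical change of one activation coordinate."""
    return _l2(delta_x) / math.sqrt(len(delta_x))
```

Dividing by $\|W^l\|_2$ makes the perturbation statistic comparable across widths, because it bounds how much $\epsilon^l$ can change the layer output relative to $W^l$ itself. The $\sqrt{\dim}$ factor turns a vector norm into a typical per-coordinate magnitude.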

Theorems & Definitions (34)

  • Proposition 1: Instability of standard SAM parameterization in wide neural networks
  • Proposition 2: Global perturbation scaling is unstable or induces vanishing perturbations
  • Theorem 3: Perturbation nontriviality characterization
  • Theorem 4: Vanishing perturbation characterization
  • Theorem 5: Effective perturbation characterization
  • Theorem 6: Maximal Perturbation Parameterization (MPP)
  • Theorem D.1: Stability characterization
  • Theorem D.2: Nontriviality characterization
  • Theorem D.3: Feature learning characterization
  • Theorem D.4: Perturbation nontriviality characterization
  • ...and 24 more