μP$^2$: Effective Sharpness Aware Minimization Requires Layerwise Perturbation Scaling
Moritz Haas, Jin Xu, Volkan Cevher, Leena Chennuru Vankadara
TL;DR
This work provides a rigorous infinite-width analysis of Sharpness Aware Minimization (SAM) via Tensor Programs. It reveals that in wide networks the dynamics of standard SAM collapse to perturbing only the last layer. It introduces the Maximal Update and Perturbation Parameterization (μP$^2$), a layerwise perturbation scaling under which every layer both learns features and is effectively perturbed. This enables transfer of both the learning rate $\eta$ and the perturbation radius $\rho$ across model widths and architectures. Across MLPs, ResNets, and Vision Transformers, the authors show empirically that μP$^2$ achieves joint transfer of $\eta$ and $\rho$ and improves generalization relative to SGD and to μP with global perturbations. The approach also extends to SAM variants such as ASAM and SAM-ON. The results offer practical scaling guidance for applying SAM to large models, and they suggest general spectral perturbation conditions from which layerwise perturbation rules can be derived for other perturbation strategies. The work lays a foundation for scalable, transferable SAM across diverse architectures and motivates future work on data-aware scaling and broader perturbation rules in deep learning optimization.
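To make the difference between one global perturbation and layerwise perturbation scaling concrete, here is a minimal NumPy sketch of a single SAM step. It assumes the weights are stored in a dict keyed by layer name and that the caller supplies a gradient function. The function name `sam_step` and the per-layer multipliers `layer_scales` are hypothetical. With every multiplier equal to one, the step reduces to standard SAM with a single radius $\rho$. The sketch does not encode the specific width-dependent scalings that μP$^2$ prescribes.

```python
import numpy as np


def sam_step(params, grad_fn, lr, rho, layer_scales=None, eps=1e-12):
    """One SAM step with optional per-layer perturbation multipliers.

    params: dict mapping layer name -> np.ndarray of weights.
    grad_fn: callable returning a dict of gradients with the same keys.
    layer_scales: hypothetical per-layer multipliers on the perturbation;
        None (all ones) recovers standard SAM with one global radius rho.
    """
    if layer_scales is None:
        layer_scales = {k: 1.0 for k in params}

    # 1) Gradient at the current weights and its global (all-layer) norm.
    grads = grad_fn(params)
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))

    # 2) Ascent step: move each layer along its gradient, normalized by the
    #    global norm and rescaled by that layer's multiplier.
    perturbed = {
        k: params[k] + rho * layer_scales[k] * grads[k] / (global_norm + eps)
        for k in params
    }

    # 3) Descent step: take the gradient at the perturbed weights and apply
    #    it to the original (unperturbed) weights.
    sam_grads = grad_fn(perturbed)
    return {k: params[k] - lr * sam_grads[k] for k in params}
```

As a usage example, take the toy loss $\tfrac12\sum_l \|W_l\|_F^2$, whose gradient is the weights themselves. The call `sam_step(params, lambda p: {k: v.copy() for k, v in p.items()}, lr=0.5, rho=0.1)` perturbs every layer. Setting a layer's multiplier to zero removes that layer's perturbation entirely, which mimics the "last layer only" effect described above in reverse.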
Abstract
Sharpness Aware Minimization (SAM) enhances performance across various neural architectures and datasets. As models are continually scaled up to improve performance, a rigorous understanding of SAM's scaling behaviour is paramount. To this end, we study the infinite-width limit of neural networks trained with SAM, using the Tensor Programs framework. Our findings reveal that the dynamics of standard SAM effectively reduce to applying SAM solely in the last layer in wide neural networks, even with optimal hyperparameters. In contrast, we identify a stable parameterization with layerwise perturbation scaling, which we call Maximal Update and Perturbation Parameterization (μP$^2$), that ensures all layers are both feature learning and effectively perturbed in the limit. Through experiments with MLPs, ResNets and Vision Transformers, we empirically demonstrate that μP$^2$ achieves hyperparameter transfer of the joint optimum of learning rate and perturbation radius across model scales. Moreover, we provide an intuitive condition to derive μP$^2$ for other perturbation rules like Adaptive SAM and SAM-ON, also ensuring balanced perturbation effects across all layers.
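The abstract mentions an intuitive condition for deriving layerwise perturbation rules. One way such a condition could be applied is sketched below, under an explicit assumption: each weight matrix of shape (fan_out, fan_in) receives a perturbation with spectral norm $\rho\sqrt{\text{fan\_out}/\text{fan\_in}}$. This target is borrowed by analogy from the spectral condition for feature learning. It is not claimed to be the paper's exact μP$^2$ rule. The helper name `spectral_perturbation` is hypothetical.

```python
import numpy as np


def spectral_perturbation(grads, rho, eps=1e-12):
    """Rescale each layer's gradient direction to a target spectral norm.

    Hypothetical rule, assumed for illustration: the perturbation of a weight
    matrix of shape (fan_out, fan_in) gets spectral norm
    rho * sqrt(fan_out / fan_in), so each layer's perturbation size is set by
    its own shape rather than by its share of a global gradient norm.
    """
    perturbations = {}
    for name, g in grads.items():
        fan_out, fan_in = g.shape
        target = rho * np.sqrt(fan_out / fan_in)
        # For a 2-D array, np.linalg.norm(..., ord=2) is the largest singular value.
        spec = np.linalg.norm(g, ord=2)
        perturbations[name] = target * g / (spec + eps)
    return perturbations
```

The design choice here is to normalize each layer separately rather than by the global gradient norm. This keeps one large-gradient layer from absorbing the whole perturbation budget, which is the kind of imbalance the abstract attributes to standard SAM at large width.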
