Sharpness-Aware Minimization Enhances Feature Quality via Balanced Learning
Jacob Mitchell Springer, Vaishnavh Nagarajan, Aditi Raghunathan
TL;DR
This work reframes Sharpness-Aware Minimization (SAM) beyond flatness-generalization arguments by identifying a feature-diversifying mechanism that improves representations when data contain multiple predictive features. By decomposing SAM's phantom-parameter updates, the authors identify two causal effects: an implicit importance-weighting effect that redistributes emphasis more evenly across examples, and a learning-rate-scaling effect that boosts learning of harder features. They validate these effects in toy settings and on real datasets (CelebA, Waterbirds, CIFAR-MNIST, DomainBed), demonstrating improved feature-probing accuracy and better transfer to downstream tasks. The findings highlight SAM's potential to improve out-of-distribution robustness without explicit upweighting or separate regularizers, with practical implications for robust representation learning and domain adaptation.
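To make "phantom-parameter updates" concrete, here is a minimal NumPy sketch, assuming full-batch SAM on a linear model with logistic loss and labels in {-1, +1}. The function names, `lr`, and `rho` are illustrative choices, not the authors' code. SAM first takes an ascent step of radius `rho` to the phantom parameters, then descends from the current parameters using the gradient evaluated at the phantom ones. For a linear logistic model that gradient is an average of per-example terms weighted by sigma(-margin), so comparing the weights at the phantom parameters with those at the current parameters shows how SAM reweights examples. The sketch covers only this reweighting view and does not attempt the learning-rate-scaling effect.

```python
import numpy as np

def logistic_grad_and_weights(w, X, y):
    """Mean logistic-loss gradient for labels y in {-1, +1}, plus per-example weights.

    The gradient is -mean_i(y_i * x_i * sigma(-m_i)) with margin m_i = y_i * (x_i . w),
    so sigma(-m_i) acts as an example weight: near 1 for badly fit examples,
    near 0 for well-fit ones.
    """
    margins = y * (X @ w)
    # sigma(-m) = (1 - tanh(m / 2)) / 2, which avoids overflow in exp for large |m|.
    weights = 0.5 * (1.0 - np.tanh(0.5 * margins))
    grad = -(X * (y * weights)[:, None]).mean(axis=0)
    return grad, weights

def sgd_step(w, X, y, lr=0.1):
    """Plain full-batch gradient step; returns new parameters and the example weights used."""
    grad, weights = logistic_grad_and_weights(w, X, y)
    return w - lr * grad, weights

def sam_step(w, X, y, lr=0.1, rho=0.05):
    """One SAM step: move to phantom parameters, then descend with the gradient taken there."""
    grad, _ = logistic_grad_and_weights(w, X, y)
    norm = np.linalg.norm(grad)
    # Ascent step of radius rho along the normalized gradient (zero if the gradient vanishes).
    eps = rho * grad / norm if norm > 0 else np.zeros_like(w)
    w_phantom = w + eps
    # The update direction is the gradient at the phantom parameters, applied at w itself.
    grad_phantom, weights_phantom = logistic_grad_and_weights(w_phantom, X, y)
    return w - lr * grad_phantom, weights_phantom

# Hypothetical usage: compare the per-example weights seen by SGD and by SAM at the same point.
rng = np.random.default_rng(0)
X = rng.normal(size=(8, 2))
y = np.where(X[:, 0] > 0, 1.0, -1.0)
w = np.array([1.0, 0.0])
_, weights_sgd = sgd_step(w, X, y)
_, weights_sam = sam_step(w, X, y, rho=0.5)
print(np.round(weights_sgd, 3))
print(np.round(weights_sam, 3))
```

With `rho=0` the phantom parameters coincide with the current ones and `sam_step` reduces to `sgd_step`, which is a convenient sanity check for any implementation.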
Abstract
Sharpness-Aware Minimization (SAM) has emerged as a promising alternative optimizer to stochastic gradient descent (SGD). The originally proposed motivation behind SAM was to bias neural networks towards flatter minima that are believed to generalize better. However, recent studies have shown conflicting evidence on the relationship between flatness and generalization, suggesting that flatness does not fully explain SAM's success. Sidestepping this debate, we identify an orthogonal effect of SAM that is beneficial out-of-distribution: we argue that SAM implicitly balances the quality of diverse features. SAM achieves this effect by adaptively suppressing well-learned features, which gives the remaining features an opportunity to be learned. We show that this mechanism is beneficial in datasets that contain redundant or spurious features, where SGD falls prey to simplicity bias and would not otherwise learn all available features. Our insights are supported by experiments on real data: we demonstrate that SAM improves the quality of features in datasets containing redundant or spurious features, including CelebA, Waterbirds, CIFAR-MNIST, and DomainBed.
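For reference, the standard SAM formulation behind the flatness motivation can be written as below, using assumed notation: training loss L, parameters w, neighborhood radius ρ, learning rate η, and the Euclidean norm. SAM minimizes the worst-case loss within a ρ-ball, approximates the inner maximization to first order, and then takes a gradient step using the gradient evaluated at the perturbed parameters.

```latex
\begin{aligned}
&\min_{w}\; \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon)
  && \text{(worst-case loss in a $\rho$-ball around $w$)}\\
&\hat{\epsilon}(w) = \rho\, \frac{\nabla L(w)}{\|\nabla L(w)\|_2}
  && \text{(first-order solution of the inner max)}\\
&w_{t+1} = w_t - \eta\, \nabla L\big(w_t + \hat{\epsilon}(w_t)\big)
  && \text{(descend with the gradient at the perturbed point)}
\end{aligned}
```

Setting ρ = 0 makes the perturbation vanish and recovers the plain (S)GD update.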
