Sharpness-Aware Minimization Enhances Feature Quality via Balanced Learning

Jacob Mitchell Springer, Vaishnavh Nagarajan, Aditi Raghunathan

TL;DR

This work reframes Sharpness-Aware Minimization (SAM) beyond flatness-based generalization arguments. It identifies a feature-diversifying mechanism that improves representations when the data contain multiple predictive features. By decomposing the gradient SAM computes at its phantom (perturbed) parameters, the authors identify two distinct effects. The first is an implicit importance weighting that spreads emphasis more evenly across examples. The second is a learning-rate scaling that boosts learning of harder features. They validate these effects in a toy setting and on real datasets (CelebA, Waterbirds, CIFAR-MNIST, DomainBed), showing improved feature-probing accuracy and better transfer to downstream tasks. The findings suggest that SAM can improve out-of-distribution robustness without explicit example upweighting or separate regularizers. This has practical implications for robust representation learning and domain adaptation.
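Below is a minimal NumPy sketch of the two effects. It rests on assumptions that are ours rather than the paper's:

- The model is $f(x) = v_\textsf{easy}\Phi_\textsf{easy}(x) + v_\textsf{hard}\Phi_\textsf{hard}(x)$.
- It is trained with logistic loss $\log(1 + e^{-y f(x)})$.
- The SAM perturbation is applied only to the last-layer weights $v$.

Under these assumptions, the gradient for feature $k$'s own parameters is $-\frac{1}{n}\sum_i \lambda_i y_i v_k \nabla\Phi_k(x_i)$, with $\lambda_i = \sigma(-y_i f(x_i))$. Evaluating this gradient at the phantom point changes two things:

- It swaps $\lambda_i$ for $\tilde{\lambda}_i$ (importance weighting).
- It swaps $v_k$ for $\tilde{v}_k$, which rescales feature $k$'s effective learning rate by $\tilde{v}_k / v_k$ (learning-rate scaling).

The helper names and the toy batch are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def importance_weights(v, phi, y):
    """lambda_i = sigma(-y_i f(x_i)) with f(x_i) = phi_i . v (logistic loss)."""
    return sigmoid(-y * (phi @ v))


def phantom_last_layer(v, phi, y, rho, eps=1e-12):
    """SAM ascent step applied only to the last-layer weights v (an assumption)."""
    lam = importance_weights(v, phi, y)
    # Gradient of the mean logistic loss w.r.t. v: -(1/n) sum_i lambda_i y_i phi_i
    g = -(lam * y) @ phi / len(y)
    return v + rho * g / (np.linalg.norm(g) + eps)


def decompose(v, phi, y, rho):
    """Return SGD weights lambda, phantom weights lambda~, and per-feature scale v~/v."""
    v_t = phantom_last_layer(v, phi, y, rho)
    lam = importance_weights(v, phi, y)       # weights used by an SGD step
    lam_t = importance_weights(v_t, phi, y)   # weights used by SAM (phantom point)
    lr_scale = v_t / v                        # assumes every v_k is nonzero
    return lam, lam_t, lr_scale


# Hypothetical batch: columns are the (easy, hard) features. With v = 1,
# y_i * phi_i is each feature's contribution to the margin; the easy feature
# dominates on the first two examples.
y = np.array([1.0, -1.0, 1.0, -1.0])
contrib = np.array([[3.0, 0.5], [3.0, 0.5], [0.5, 0.5], [0.5, 0.5]])
phi = y[:, None] * contrib
v = np.array([1.0, 1.0])  # (v_easy, v_hard)

lam, lam_t, lr_scale = decompose(v, phi, y, rho=0.5)
print("lambda (SGD):  ", lam.round(3))
print("lambda~ (SAM): ", lam_t.round(3))
print("v~/v (easy, hard):", lr_scale.round(3))
```

In this toy batch, the easy feature contributes more to the margin, so the ascent step shrinks $v_\textsf{easy}$ more than $v_\textsf{hard}$. As a result, $\tilde{v}_\textsf{hard}/v_\textsf{hard} > \tilde{v}_\textsf{easy}/v_\textsf{easy}$. The phantom weights $\tilde{\lambda}_i$ are also less concentrated on the low-margin examples than the $\lambda_i$.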

Abstract

Sharpness-Aware Minimization (SAM) has emerged as a promising alternative optimizer to stochastic gradient descent (SGD). The originally proposed motivation behind SAM was to bias neural networks towards flatter minima that are believed to generalize better. However, recent studies have shown conflicting evidence on the relationship between flatness and generalization, suggesting that flatness does not fully explain SAM's success. Sidestepping this debate, we identify an orthogonal effect of SAM that is beneficial out-of-distribution: we argue that SAM implicitly balances the quality of diverse features. SAM achieves this effect by adaptively suppressing well-learned features, which gives the remaining features an opportunity to be learned. We show that this mechanism is beneficial in datasets that contain redundant or spurious features. In such datasets, SGD falls prey to simplicity bias and would otherwise not learn all available features. Our insights are supported by experiments on real data: we demonstrate that SAM improves the quality of features in datasets containing redundant or spurious features, including CelebA, Waterbirds, CIFAR-MNIST, and DomainBed.
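For readers new to SAM, here is a minimal NumPy sketch of the commonly stated two-step update. First, take an ascent step of radius $\rho$ to the phantom point. Then update the real parameters with the gradient computed there. The function names and the quadratic test loss are illustrative assumptions, not the authors' code. In practice, both gradients are computed on the same mini-batch.

```python
import numpy as np


def sam_step(w, grad_fn, lr=0.1, rho=0.05, eps=1e-12):
    """One SAM update on parameters w.

    1. Ascent: move to the phantom point w + rho * g / ||g||, the first-order
       worst case within an L2 ball of radius rho around w.
    2. Descent: update the *real* parameters w using the gradient at the phantom point.
    With rho = 0 this is an ordinary gradient step.
    """
    g = grad_fn(w)
    w_phantom = w + rho * g / (np.linalg.norm(g) + eps)
    return w - lr * grad_fn(w_phantom)


def gd_step(w, grad_fn, lr=0.1):
    """Plain gradient step, for comparison."""
    return w - lr * grad_fn(w)


# Illustrative quadratic loss L(w) = 0.5 * w^T A w with one flat and one sharp direction.
A = np.diag([1.0, 10.0])


def grad(w):
    return A @ w


w0 = np.array([1.0, 1.0])
w_sam = sam_step(w0, grad, lr=0.05, rho=0.1)
w_gd = gd_step(w0, grad, lr=0.05)
print("SAM:", w_sam, "GD:", w_gd)
```

Here the ascent step moves mostly along the sharp direction. As a result, SAM's descent step along that direction is larger than plain gradient descent's.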

Paper Structure (64 sections, 27 equations, 10 figures, 2 tables)

Figures (10)

  • Figure 1: Illustration of the toy example. (i) The toy data distribution. We vary the complexity of the spiral component of the data by tightening the spiral. (ii) Decision boundaries of classifiers trained with SGD and with LSAM along a 2D slice where the other feature …
  • Figure 2: (A) Test-set probing error of the harder feature as a function of the SAM perturbation radius $\rho$, for several complexities of the hard feature. (B) Phantom weight ratio as a function of $\rho$: we plot $\tilde{v}_\textsf{hard} / \tilde{v}_\textsf{easy}$ (solid lines) and $v_\textsf{hard} / v_\textsf{easy}$ (dashed lines) for different values of $\rho$ when running LSAM. In both panels, SGD corresponds to $\rho=0$. As $\rho$ increases, the ratio $\tilde{v}_\textsf{hard} / \tilde{v}_\textsf{easy}$ increases.
  • Figure 3: Lorenz curves for the real and phantom importance weights $\lambda_i$ and $\tilde{\lambda}_i$. The dotted diagonal line is the Lorenz curve of a uniform distribution; the closer a curve lies to this diagonal, the more evenly the importance weights are spread. Blue curves correspond to an SGD checkpoint and orange curves to an LSAM checkpoint. The gradient for the update step is computed at the real parameter for SGD and at the phantom parameter for SAM. We include curves for the toy setting (left), CelebA (center), and Waterbirds (right).
  • Figure 4: The learning-rate re-weighting factor for the easy feature, $\tilde{v}_\textsf{easy} / v_\textsf{easy}$, and for the hard feature, $\tilde{v}_\textsf{hard} / v_\textsf{hard}$, plotted as a distribution over batches in the dataset (see the gradient decomposition of the toy loss).
  • Figure 5: Median importance weight as a function of the contribution of each feature. We partition the data into bins based on the contributions of the easy and hard features ($y v_\textsf{easy} \Phi_\textsf{easy}$ and $y v_\textsf{hard} \Phi_\textsf{hard}$), as defined in the section on the two feature-diversifying effects. For each bin, we plot the median importance weight $\lambda_i$ of the points in the bin. We include the corresponding plots for the toy setting (top), CelebA (center), and Waterbirds (bottom).
  • ...and 5 more figures
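Two of the figure analyses can be reproduced with a few lines of NumPy for any vector of per-example importance weights: the Lorenz curves of Figure 3 and the binned medians of Figure 5. The sketch below is our own illustration, not the authors' code. The following are assumptions:

- the function names;
- the Gini coefficient, a standard one-number summary of a Lorenz curve that the captions do not use;
- the example weights.

```python
import numpy as np


def lorenz_curve(weights):
    """Points of the Lorenz curve of nonnegative weights.

    Returns (population share, cumulative weight share), sorted from the smallest
    weight to the largest and starting at (0, 0). Equal weights give the diagonal.
    """
    w = np.sort(np.asarray(weights, dtype=float))
    share = np.concatenate([[0.0], np.cumsum(w)]) / w.sum()
    pop = np.linspace(0.0, 1.0, len(w) + 1)
    return pop, share


def gini(weights):
    """1 - 2 * (area under the Lorenz curve): 0 for equal weights, larger when concentrated."""
    pop, share = lorenz_curve(weights)
    area = np.sum(0.5 * (share[1:] + share[:-1]) * np.diff(pop))  # trapezoid rule
    return 1.0 - 2.0 * area


def median_weight_by_bins(contrib_easy, contrib_hard, lam, edges):
    """Median importance weight per 2D bin of (easy, hard) feature contribution.

    contrib_* hold per-example contributions such as y_i v_k Phi_k(x_i).
    Bin a covers edges[a] <= value < edges[a + 1]; examples outside
    [edges[0], edges[-1]) are dropped, and empty bins are NaN.
    """
    lam = np.asarray(lam, dtype=float)
    ie = np.digitize(contrib_easy, edges) - 1
    ih = np.digitize(contrib_hard, edges) - 1
    nb = len(edges) - 1
    out = np.full((nb, nb), np.nan)
    for a in range(nb):
        for b in range(nb):
            mask = (ie == a) & (ih == b)
            if mask.any():
                out[a, b] = np.median(lam[mask])
    return out


# Hypothetical weights (not from the paper): concentrated vs. more evenly spread.
lam_sgd = np.array([0.03, 0.03, 0.27, 0.27])
lam_sam = np.array([0.11, 0.11, 0.34, 0.34])
print("Gini SGD:", gini(lam_sgd), "Gini SAM:", gini(lam_sam))
```

A curve closer to the diagonal yields a Gini value closer to 0. This mirrors the caption's reading that weights "closer to the diagonal" are spread more equally.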