Avoiding spurious sharpness minimization broadens applicability of SAM

Sidak Pal Singh, Hossein Mobahi, Atish Agarwala, Yann Dauphin

TL;DR

The paper reveals that SAM’s success in vision does not translate to NLP because SAM largely reduces sharpness via logit manipulation rather than improving the function’s geometry. It decomposes the sharpness gradient into a logit path and a functional path using a Gauss–Newton framework, showing NLP settings are logit-dominated. To address this, it proposes Functional-SAM, which targets the functional path, and a preconditioned-SAM variant to align SAM with optimizer geometry; combining them yields the best generalization across model scales and training regimes. Empirical results on decoder-only Transformers trained with C4 demonstrate consistent improvements over AdamW and SAM baselines, with flatter final solutions (lower Hessian traces and maxima) and robustness at large scales. The work pushes forward a more nuanced understanding of sharpness and its practical control, broadening curvature-regularization applicability to large language models.

Abstract

Curvature regularization techniques like Sharpness Aware Minimization (SAM) have shown great promise in improving generalization on vision tasks. However, we find that SAM performs poorly in domains like natural language processing (NLP), often degrading performance -- even with twice the compute budget. We investigate the discrepancy across domains and find that in the NLP setting, SAM is dominated by regularization of the logit statistics -- instead of improving the geometry of the function itself. We use this observation to develop an alternative algorithm we call Functional-SAM, which regularizes curvature only through modification of the statistics of the overall function implemented by the neural network, and avoids spurious minimization through logit manipulation. Furthermore, we argue that preconditioning the SAM perturbation also prevents spurious minimization, and when combined with Functional-SAM, it gives further improvements. Our proposed algorithms show improved performance over AdamW and SAM baselines when trained for an equal number of steps, in both fixed-length and Chinchilla-style training settings, at various model scales (including billion-parameter scale). On the whole, our work highlights the importance of more precise characterizations of sharpness in broadening the applicability of curvature regularization to large language models (LLMs).
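To make the distinction between the two pathways concrete, here is a minimal sketch on a toy scalar model, not the paper's implementation. It assumes the standard SAM step (perturb the parameters along the normalized gradient with radius `rho`, then take the gradient at the perturbed point). It also assumes one plausible reading of the "functional path": keep the Jacobian of the perturbed network but use the loss derivative at the unperturbed logits. The model `z = b * tanh(a * x)`, the logistic loss, and all function names are hypothetical choices made for illustration.

```python
import math

def logit(theta, x):
    """Toy network output z = b * tanh(a * x) with parameters theta = (a, b)."""
    a, b = theta
    return b * math.tanh(a * x)

def logit_jacobian(theta, x):
    """Jacobian of the logit w.r.t. (a, b)."""
    a, b = theta
    t = math.tanh(a * x)
    return (b * x * (1.0 - t * t), t)

def loss_grad_wrt_logit(z, y):
    """d/dz of the logistic loss log(1 + exp(-y * z)), with y in {-1, +1}."""
    return -y / (1.0 + math.exp(y * z))

def plain_gradient(theta, x, y):
    """Chain rule: Jacobian^T times the loss derivative w.r.t. the logit."""
    j = logit_jacobian(theta, x)
    dl = loss_grad_wrt_logit(logit(theta, x), y)
    return (j[0] * dl, j[1] * dl)

def perturb(theta, x, y, rho):
    """SAM ascent step: move distance rho along the normalized gradient."""
    g = plain_gradient(theta, x, y)
    norm = math.hypot(g[0], g[1])
    if norm == 0.0 or rho == 0.0:
        return theta
    return (theta[0] + rho * g[0] / norm, theta[1] + rho * g[1] / norm)

def sam_gradient(theta, x, y, rho):
    """Standard SAM: both the Jacobian and the loss derivative are perturbed."""
    return plain_gradient(perturb(theta, x, y, rho), x, y)

def functional_path_gradient(theta, x, y, rho):
    """Assumed functional-path variant: perturbed Jacobian,
    but the loss derivative is taken at the UNPERTURBED logit."""
    j = logit_jacobian(perturb(theta, x, y, rho), x)
    dl = loss_grad_wrt_logit(logit(theta, x), y)
    return (j[0] * dl, j[1] * dl)

if __name__ == "__main__":
    theta, x, y = (0.5, 1.0), 2.0, 1.0
    print("SAM gradient:       ", sam_gradient(theta, x, y, rho=0.1))
    print("functional gradient:", functional_path_gradient(theta, x, y, rho=0.1))
```

With `rho = 0` both functions reduce to the plain gradient. For a scalar logit the two gradients are parallel and differ only in scale; with vector-valued logits, as in language modeling, they can also differ in direction.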

Paper Structure

This paper contains 24 sections, 18 equations, 5 figures, 6 tables.

Figures (5)

  • Figure 1: Evaluation loss curves of AdamW and SAM for the Nanodo decoder-only Transformer model on the C4 dataset (Raffel et al., 2020).
  • Figure 2: Normalized sharpness contributions $\tau_{\text{logit}}$ and $\tau_{\text{func}}$ show dramatically different trends across modalities. For ViT trained on ImageNet-1K (top left) and JFT (top right), $\tau_{\text{logit}}$ starts near $0$ but quickly increases to a magnitude comparable to that of $\tau_{\text{func}}$. For Transformer models trained on C4 (bottom left and bottom right), $\tau_{\text{logit}} \gg \tau_{\text{func}}$ after the first few steps of training. This suggests that the pathways to sharpness regularization are more imbalanced in NLP than in vision settings, which may contribute to the poor performance of SAM in NLP settings. $\tau_{\text{cross}}$ (plotted in the appendix) is usually negative, suggesting the two methods of sharpness regularization tend to be antagonistic.
  • Figure 3: Effect of increasing perturbation strength for SAM and precond Functional-SAM at equal compute. SAM (red squares) does worse than the baseline at non-zero $\rho$, while precond Functional-SAM (orange circles), at the same compute cost, shows improvements (Nanodo trained on C4, $23.9$M parameters, $10$K steps). Loss is measured relative to the best validation metric for precond Functional-SAM for illustrative purposes.
  • Figure 4: Sharpness contributions $\tau_{\text{logit}}$, $\tau_{\text{func}}$ and $\tau_{\text{cross}}$ for various datasets. $\tau_{\text{cross}}$ tends to be negative for most of training.
  • Figure 5: Effect of one-shot (unstructured) global magnitude pruning: sharpness minimization methods tend to degrade more gracefully as an increasing number of parameters are pruned. The figure also shows that the performance gain imparted by Functional-SAM over AdamW is equivalent to setting about $25\%$ of the parameters to zero, and is thus significant.
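
The pruning procedure named in the Figure 5 caption can be sketched as follows. This is a minimal illustration of one-shot, unstructured, global magnitude pruning in general, not the paper's code. It assumes the weights are given as a dictionary of flat lists, and the function name `global_magnitude_prune` and the example layer names are hypothetical.

```python
def global_magnitude_prune(layers, fraction):
    """Zero out the `fraction` of weights with the smallest absolute value,
    ranked globally across all layers (one shot, no retraining).

    layers:   dict mapping a layer name to a flat list of float weights.
    fraction: share of all weights to prune, in [0, 1].
    Returns a new dict; the input is left unchanged.
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must lie in [0, 1]")

    # Collect (magnitude, layer name, index) for every weight in the model.
    entries = [
        (abs(w), name, i)
        for name, weights in layers.items()
        for i, w in enumerate(weights)
    ]
    # Number of weights to remove; rounded down.
    k = int(fraction * len(entries))

    # Sort by magnitude only, so ties keep their original order.
    entries.sort(key=lambda e: e[0])

    pruned = {name: list(weights) for name, weights in layers.items()}
    for _, name, i in entries[:k]:
        pruned[name][i] = 0.0
    return pruned


if __name__ == "__main__":
    toy = {
        "layer_a": [0.1, -2.0, 0.5, 3.0],
        "layer_b": [-0.05, 1.0, -0.2, 4.0],
    }
    print(global_magnitude_prune(toy, 0.25))
```

Because the ranking is global rather than per layer, layers whose weights are small overall lose a larger share of their parameters than layers with large weights.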