Why is SAM Robust to Label Noise?

Christina Baek, Zico Kolter, Aditi Raghunathan

TL;DR

This work examines why Sharpness-Aware Minimization (SAM) is unusually robust to label noise, especially under early stopping. It decomposes the per-example gradient into a logit-scale term and a network Jacobian term. The logit-scale term can up-weight clean examples in linear settings, but in nonlinear networks the robustness stems primarily from SAM’s perturbation of the Jacobian, which effectively regularizes the network’s activations and last-layer weights. An analysis of two-layer deep linear networks shows that Jacobian-based SAM updates behave like SGD with L2-type penalties on the intermediate activations and final-layer weights, giving an interpretable regularization mechanism. Empirically, a Jacobian-only variant (J-SAM) recovers most of SAM’s gains across architectures, and simple regularized SGD closely approximates SAM’s performance under label noise, suggesting cheaper alternatives that capture the same robustness. Overall, the paper shifts the explanation of SAM’s label-noise robustness from the sharpness of the converged solution to the optimization trajectory and Jacobian regularization, and it offers practical ways to obtain similar benefits at reduced computational cost.

Abstract

Sharpness-Aware Minimization (SAM) is best known for achieving state-of-the-art performance on natural image and language tasks. However, its most pronounced improvements (of tens of percent) are rather in the presence of label noise. Understanding SAM's label noise robustness requires a departure from characterizing the robustness of minima lying in "flatter" regions of the loss landscape. In particular, the peak performance under label noise occurs with early stopping, far before the loss converges. We decompose SAM's robustness into two effects: one induced by changes to the logit term and the other induced by changes to the network Jacobian. The first can be observed in linear logistic regression, where SAM provably up-weights the gradient contribution from clean examples. Although this explicit up-weighting is also observable in neural networks, when we intervene and modify SAM to remove this effect, surprisingly, we see no visible degradation in performance. We infer that SAM's effect in deeper networks is instead explained entirely by the effect SAM has on the network Jacobian. We theoretically derive the implicit regularization induced by this Jacobian effect in two-layer linear networks. Motivated by our analysis, we see that cheaper alternatives to SAM that explicitly induce these regularization effects largely recover the benefits in deep networks trained on real-world datasets.
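
The decomposition described in the abstract can be sketched on a toy model. The block below is a minimal illustration, not the authors' implementation. It assumes a two-layer linear network with scalar output $f(x) = v^\top W x$, labels in $\{-1, +1\}$, logistic loss, and a per-example SAM perturbation of radius `rho`. All function names are hypothetical. The per-example gradient factors into a scalar logit-scale term times the Jacobian of the output. SAM evaluates both factors at the perturbed weights, while the J-SAM variant, as read here, keeps the logit scale at the original weights and perturbs only the Jacobian.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(W, v, x):
    """Two-layer linear network with scalar output f(x) = v^T W x."""
    return v @ (W @ x)

def logit_scale(W, v, x, y):
    """d loss / d f for the logistic loss log(1 + exp(-y f)), y in {-1, +1}."""
    return -y * sigmoid(-y * forward(W, v, x))

def jacobian(W, v, x):
    """Gradient of the output f with respect to W and v."""
    return np.outer(v, x), W @ x

def sgd_grad(W, v, x, y):
    """Plain gradient: logit scale times Jacobian, both at (W, v)."""
    s = logit_scale(W, v, x, y)
    JW, Jv = jacobian(W, v, x)
    return s * JW, s * Jv

def perturb(W, v, x, y, rho):
    """Ascent step of length rho along the normalized gradient (assumed per-example)."""
    gW, gv = sgd_grad(W, v, x, y)
    norm = np.sqrt(np.sum(gW ** 2) + np.sum(gv ** 2)) + 1e-12
    return W + rho * gW / norm, v + rho * gv / norm

def sam_grad(W, v, x, y, rho):
    """SAM: logit scale and Jacobian are both evaluated at the perturbed weights."""
    Wp, vp = perturb(W, v, x, y, rho)
    return sgd_grad(Wp, vp, x, y)

def jsam_grad(W, v, x, y, rho):
    """J-SAM sketch: perturbed Jacobian, unperturbed logit scale."""
    Wp, vp = perturb(W, v, x, y, rho)
    s = logit_scale(W, v, x, y)
    JW, Jv = jacobian(Wp, vp, x)
    return s * JW, s * Jv
```

In this sketch the SAM and J-SAM gradients point in the same direction for a single example and differ only by the scalar ratio of the two logit-scale terms. Across a batch, that ratio varies from example to example, and this variation is the reweighting effect that J-SAM removes.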

Paper Structure (38 sections, 2 theorems, 24 equations, 14 figures, 1 table)

Key Result

Lemma 3.1

Consider the following function. This function is strictly increasing if $C > 0$.
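
The function itself is not reproduced on this page. Purely as an illustration, and as an assumption rather than the paper's confirmed statement, one candidate consistent with the lemma's title is the ratio of the SAM and SGD logit-scale terms in logistic regression, $f(z) = \frac{1 + e^{z}}{1 + e^{z - C}}$. Here $z$ is an example's margin (larger margin means lower loss) and $C$ is a hypothetical constant standing in for the size of the SAM perturbation. Its derivative has the sign of $e^{z}(1 - e^{-C})$, which is positive exactly when $C > 0$. The sketch below checks this monotonicity numerically.

```python
import numpy as np

def upweight_ratio(z, C):
    """Assumed ratio (1 + e^z) / (1 + e^(z - C)), computed stably in log space."""
    return np.exp(np.logaddexp(0.0, z) - np.logaddexp(0.0, z - C))

# Margins from badly misclassified (negative) to confidently correct (positive).
z = np.linspace(-5.0, 5.0, 11)
ratios = upweight_ratio(z, C=1.0)

# With C > 0 the ratio grows with the margin, so low-loss points are up-weighted most.
is_increasing = bool(np.all(np.diff(ratios) > 0))
```

Under this assumed form the ratio stays between 1 and $e^{C}$, so the up-weighting is bounded, and it vanishes when $C = 0$.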

Figures (14)

  • Figure 1: CIFAR10 training accuracy and loss in clean versus noisy data. SAM achieves a higher clean training accuracy before fitting the noisy data, i.e., when accuracy of noisy training data surpasses random chance. This corresponds with a higher peak in test accuracy.
  • Figure 2: SAM learns clean examples faster than noisy examples.
  • Figure 3: Linear models trained on the toy Gaussian data using SAM. SAM's preferential up-weighting of low loss points (right) corresponds with higher early stopping test accuracy (left).
  • Figure 4: In 2-layer deep linear networks (DLN), 2-layer MLP with ReLU activation (2NN), and ResNet18 trained on noisy CIFAR10, we observe that SAM's perturbation to the logit scale preferentially upweights the gradient norm for clean examples (top row). Yet J-SAM, i.e., SAM absent the explicit reweighting effect, preserves SAM's label noise robustness (bottom row).
  • Figure 5: When training ResNet18 with SAM, the norm of the final intermediate activations and last layer weights decreases significantly, consistent with 2-layer DLN analysis.
  • ...and 9 more figures

Theorems & Definitions (2)

  • Lemma 3.1: Preferential up-weighting of low loss points
  • Proposition 4.1