Monge SAM: Robust Reparameterization-Invariant Sharpness-Aware Minimization Based on Loss Geometry

Albert Kjøller Jacobsen, Georgios Arvanitidis

TL;DR

Monge SAM (M-SAM) is a reparameterization-invariant variant of sharpness-aware minimization (SAM). It defines a Monge metric $\mathbf{G}(\boldsymbol{\theta}) = \mathbf{I}_K + \nabla \ell(\boldsymbol{\theta}) \nabla \ell(\boldsymbol{\theta})^\top$ and derives a closed-form worst-case perturbation $\boldsymbol{\delta}_{\text{M-SAM}}^* = \frac{1}{\sqrt{1+\|\nabla \ell(\boldsymbol{\theta})\|_2^2}} \cdot \big(\frac{\rho}{\|\nabla \ell(\boldsymbol{\theta})\|_2} \nabla \ell(\boldsymbol{\theta})\big)$, which is the SAM perturbation scaled by a factor that shrinks as the gradient norm grows. As a result, M-SAM interpolates between gradient descent (GD) and SAM, which makes it more robust to hyperparameter choices and less attracted to saddle points, while remaining applicable to any modeling choice. The authors provide theoretical stability analyses and empirical demonstrations on toy 2D losses, CIFAR-10 fine-tuning, and cross-modal CLIP alignment, showing improved stability and representational alignment over SAM in several settings. Overall, M-SAM is a geometry-aware alternative to SAM that uses the geometry of the loss surface to achieve reparameterization invariance, with practical benefits for generalization and robustness.
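To make the closed-form perturbation concrete, the following is a minimal sketch in plain Python, not the authors' implementation. The function names (`sam_perturbation`, `msam_perturbation`, `norm`), the example gradients, and the convention of returning a zero perturbation when the gradient is exactly zero are assumptions made for illustration only.

```python
import math


def norm(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in v))


def sam_perturbation(grad, rho):
    """SAM perturbation: a step of Euclidean length rho along the gradient.

    Assumption: a zero gradient yields a zero perturbation.
    """
    n = norm(grad)
    if n == 0.0:
        return [0.0 for _ in grad]
    return [rho * g / n for g in grad]


def msam_perturbation(grad, rho):
    """M-SAM perturbation from the TL;DR formula:
    the SAM perturbation scaled by 1 / sqrt(1 + ||grad||^2).
    """
    scale = 1.0 / math.sqrt(1.0 + sum(g * g for g in grad))
    return [scale * d for d in sam_perturbation(grad, rho)]


rho = 0.5
small_grad = [1e-3, -2e-3]   # nearly flat region: M-SAM behaves like SAM
large_grad = [30.0, -40.0]   # steep region: M-SAM perturbation shrinks toward zero (GD)

for grad in (small_grad, large_grad):
    print(
        "grad norm:", norm(grad),
        "| SAM step:", norm(sam_perturbation(grad, rho)),
        "| M-SAM step:", norm(msam_perturbation(grad, rho)),
    )
```

The two example gradients illustrate the interpolation described above: for a small gradient norm the M-SAM step length is close to $\rho$, and for a large gradient norm it is close to zero.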

Abstract

Recent studies on deep neural networks show that flat minima of the loss landscape correlate with improved generalization. Sharpness-aware minimization (SAM) efficiently finds flat regions by updating the parameters according to the gradient at an adversarial perturbation. The perturbation depends on the Euclidean metric, making SAM non-invariant under reparametrizations, which blurs sharpness and generalization. We propose Monge SAM (M-SAM), a reparametrization invariant version of SAM by considering a Riemannian metric in the parameter space induced naturally by the loss surface. Compared to previous approaches, M-SAM works under any modeling choice, relies only on mild assumptions while being as computationally efficient as SAM. We theoretically argue that M-SAM varies between SAM and gradient descent (GD), which increases robustness to hyperparameter selection and reduces attraction to suboptimal equilibria like saddle points. We demonstrate this behavior both theoretically and empirically on a multi-modal representation alignment task.
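The claim that M-SAM varies between SAM and GD can be checked directly from the Monge metric $\mathbf{G}(\boldsymbol{\theta}) = \mathbf{I}_K + \nabla \ell(\boldsymbol{\theta}) \nabla \ell(\boldsymbol{\theta})^\top$ quoted in the TL;DR. The sketch below assumes, by analogy with SAM, that the perturbation maximizes the linearized loss $\nabla \ell(\boldsymbol{\theta})^\top \boldsymbol{\delta}$ subject to $\boldsymbol{\delta}^\top \mathbf{G}(\boldsymbol{\theta}) \boldsymbol{\delta} \leq \rho^2$; this problem setup is an assumption and not a quotation of the paper's derivation. Writing $\mathbf{g} = \nabla \ell(\boldsymbol{\theta})$ and applying the Sherman-Morrison formula:

$$
\mathbf{G}^{-1} = \mathbf{I}_K - \frac{\mathbf{g}\mathbf{g}^\top}{1 + \|\mathbf{g}\|_2^2},
\qquad
\mathbf{G}^{-1}\mathbf{g} = \frac{\mathbf{g}}{1 + \|\mathbf{g}\|_2^2},
\qquad
\mathbf{g}^\top \mathbf{G}^{-1} \mathbf{g} = \frac{\|\mathbf{g}\|_2^2}{1 + \|\mathbf{g}\|_2^2}
$$

$$
\boldsymbol{\delta}^* = \rho\, \frac{\mathbf{G}^{-1}\mathbf{g}}{\sqrt{\mathbf{g}^\top \mathbf{G}^{-1} \mathbf{g}}}
= \frac{1}{\sqrt{1 + \|\mathbf{g}\|_2^2}} \cdot \frac{\rho}{\|\mathbf{g}\|_2}\, \mathbf{g},
\qquad
\|\boldsymbol{\delta}^*\|_2 = \frac{\rho}{\sqrt{1 + \|\mathbf{g}\|_2^2}} \leq \rho
$$

Under this assumption the result agrees with the closed-form expression in the TL;DR. As $\|\mathbf{g}\|_2 \to 0$ the perturbation length approaches $\rho$, which is the SAM perturbation; as $\|\mathbf{g}\|_2 \to \infty$ it approaches $0$, so the update reduces to a plain GD step.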

Paper Structure

This paper contains 22 sections, 18 equations, 7 figures, 3 tables.

Figures (7)

  • Figure 1: SAM finds the adversarial perturbation within a Euclidean ball, which upper-bounds the M-SAM perturbation based on the local geometry of the loss, implying an adaptive trade-off between SAM and GD. In a loss defined by $\ell\left(\boldsymbol{\theta}\right) = \left(1-\theta_1\theta_2\right)^2$ with banana-shaped minima at $\theta_1\!=\!{1}/{\theta_2}$, M-SAM is less prone than SAM to being attracted to the saddle point at $\boldsymbol{\theta}_s = \left(0,0\right)$. M-SAM can reach lower losses like GD while being capable of walking along minima, eventually finding the flattest global minimum at $\boldsymbol{\theta}^\ast_{\text{flat}} = \left(-1,-1\right)$. We run 200 steps from $\boldsymbol{\theta}_0=(-\frac{3}{2}, \frac{1}{2})$ with $\rho=1$ and a learning rate of $0.01$. Arrows represent the respective gradients (rescaled) at the perturbed points.
  • Figure 2: Attraction to maxima. We consider the scaled 2D $\mathrm{sinc}$-function given by $\mathrm{sinc}\left(x, y\right) = 5\cdot\sin \left(x^2+y^2\right) / \left(x^2+y^2\right)$ and draw $N=200$ samples of $\boldsymbol{\theta}=(x,y)$, uniformly distributed within the centered unit square. For each sample, we run SAM and M-SAM for 200 steps with a learning rate of 0.05 and $\rho \in \{0.2, 1\}$ and plot the distribution of the converged estimates, $\boldsymbol{\theta}^\ast$. The larger perturbation radius $\rho=1$ makes the global maximum at $\boldsymbol{\theta}=(0,0)$ a stronger attractor for SAM, whereas for $\rho=0.2$ SAM is more likely to descend into the surrounding circular range of minima. We observe similar trends for M-SAM, but the conservative property limits how strong an attractor the maximum is, even for the large perturbation radius of $\rho=1$.
  • Figure 3: Conservative property. SAM vs. M-SAM behavior in a simple paraboloid loss given by $\ell\left(\theta\right) = \theta^2$ for two learning rates, $\alpha_1 < \alpha_2$, with a fixed perturbation radius $\rho=0.3$ and $11$ steps. M-SAM always converges to lower values than SAM and is only mildly affected by large learning rates, revealing the conservativeness of M-SAM that originates from being loss-aware.
  • Figure 4: Stability criteria. Intuitively, we highlight the stability of a parameter $\boldsymbol{\theta}$ by considering the behavior of a random sample $\tilde{\boldsymbol{\theta}} = \boldsymbol{\theta} + \boldsymbol{\epsilon}$ within the $\boldsymbol{\epsilon}$-neighborhood of $\boldsymbol{\theta}$. We consider the banana-shaped loss, i.e. $\ell\left(\boldsymbol{\theta}\right) = \left(1-\theta_1\theta_2\right)^2$, and compute the eigenvalues, $\{\lambda_1, \lambda_2\}$, of $\mathbf{A}_{\rho}\left(\boldsymbol{\theta}\right)$. We show $\lambda_2$ across the parameter space for varying perturbation radii for SAM and M-SAM. As $\lambda_1$ is always positive, stability of SAM and M-SAM dynamics is satisfied when $\lambda_2$ is positive (red regions). The square is used for reference. When $\rho=0$, i.e. GD behavior, the $\boldsymbol{\epsilon}$-neighborhood of any point in parameter space will converge to the global minima, revealing that GD does not suffer from saddle point attraction unless it is initialized exactly at the saddle point, $\boldsymbol{\theta}_s=\left(0,0\right)$. In contrast, $\boldsymbol{\theta}_s$ is stable for SAM and M-SAM for higher $\rho$. Even so, M-SAM's region of attraction is smaller near the saddle but larger near the flat global minima at $\boldsymbol{\theta}_{\text{flat}}^\ast \in \{(-1, -1), (1,1)\}$, due to its conservative nature.
  • Figure 5: Comparing sharpness-aware minimizers. Trajectories of SAM, M-SAM, Fisher SAM and ASAM in the banana-shaped loss given by $\ell\left(\boldsymbol{\theta}\right) = (1-\theta_1\theta_2)^2$. Each optimizer is initialized on the points of a $10\times10$ grid and runs to convergence. The marginal distributions of convergence locations summarize the final estimates. While all methods can explore the valley of minima and locate the flattest solution, SAM exhibits a notable tendency to converge to the saddle point. M-SAM and Fisher SAM show qualitatively similar behavior, with only minor differences when initialized near the saddle. We note that this behavior is due to the diagonal approximation of the Fisher metric. We compare with ASAM for completeness; it mostly avoids the saddle point but can converge anywhere on the range of minima and generally does not find one of the flattest solutions.
  • ...and 2 more figures
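
The toy setting described in the Figure 1 caption can be approximated with a short script. This is a hedged sketch, not the authors' code: it assumes the standard SAM-style update $\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \alpha \nabla \ell(\boldsymbol{\theta} + \boldsymbol{\delta})$, uses the M-SAM perturbation from the TL;DR, and treats GD as the special case with zero perturbation. The function names (`loss`, `grad`, `perturbation`, `run`) are hypothetical, and returning a zero perturbation at an exactly zero gradient is an assumed convention. The starting point, $\rho$, learning rate, and step count are taken from the caption.

```python
import math


def loss(theta):
    """Banana-shaped loss (1 - theta_1 * theta_2)^2 from the Figure 1 caption."""
    t1, t2 = theta
    return (1.0 - t1 * t2) ** 2


def grad(theta):
    """Analytic gradient of the banana-shaped loss."""
    t1, t2 = theta
    r = 1.0 - t1 * t2
    return (-2.0 * t2 * r, -2.0 * t1 * r)


def perturbation(g, rho, method):
    """Perturbation for 'gd' (none), 'sam' (length rho), or 'msam' (shrunk SAM)."""
    n = math.hypot(g[0], g[1])
    if method == "gd" or n == 0.0:  # assumed convention at a zero gradient
        return (0.0, 0.0)
    scale = rho / n
    if method == "msam":
        scale /= math.sqrt(1.0 + n * n)
    return (scale * g[0], scale * g[1])


def run(method, theta0=(-1.5, 0.5), rho=1.0, lr=0.01, steps=200):
    """Iterate theta <- theta - lr * grad(theta + delta) and record the path."""
    theta = theta0
    path = [theta]
    for _ in range(steps):
        d = perturbation(grad(theta), rho, method)
        g = grad((theta[0] + d[0], theta[1] + d[1]))
        theta = (theta[0] - lr * g[0], theta[1] - lr * g[1])
        path.append(theta)
    return path


paths = {m: run(m) for m in ("gd", "sam", "msam")}
for m, path in paths.items():
    print(m, "final theta:", path[-1], "final loss:", loss(path[-1]))
```

Comparing the printed end points against the saddle at $(0,0)$ and the flat minimum at $(-1,-1)$ gives a rough, text-only counterpart to the trajectories shown in the figure; the exact numbers depend on the assumptions listed above.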