How to escape sharp minima with random perturbations

Kwangjun Ahn, Ali Jadbabaie, Suvrit Sra

TL;DR

The paper formalizes flat minima by introducing the normalized Hessian-trace measure $\overline{\mathrm{tr}}(x)=\frac{\mathrm{tr}(\nabla^2 f(x))}{d}$ and a notion of flat minima defined through the gradient-flow limit map $\Phi$. It then presents two gradient-based algorithms that efficiently find approximate flat minima: (i) Randomly Smoothed Perturbation (RS), which uses gradients at randomly perturbed iterates to estimate a descent direction for $\overline{\mathrm{tr}}$ and reaches an $(\epsilon,\sqrt{\epsilon})$-flat minimum in $T=\mathcal{O}(\epsilon^{-3})$ iterations (up to constants, with probability $1-\mathcal{O}(\delta)$); and (ii) Sharpness-Aware Perturbation (SA), a data-driven variant inspired by sharpness-aware minimization (SAM) that attains the same flatness goal with a faster rate in high dimensions, specifically $T=\mathcal{O}\left( d^{-1}\epsilon^{-2} \max\{1,\frac{1}{d^3\epsilon}\}\delta^{-4}\right)$ iterations, followed by gradient descent to finish. The theoretical analysis shows that random perturbations let gradients carry higher-order curvature information, which drives the trace of the Hessian downward, while the stochastic gradients used in SA further accelerate this descent. Experiments on CIFAR-10 with ResNet-18 corroborate the theory: SA escapes sharp minima more effectively than RS, even with small perturbation radii and across varying batch sizes. Overall, the work provides formal definitions, complexity bounds, and empirically validated algorithms that connect gradient-based optimization to the pursuit of flat minima, with practical implications for generalization in deep learning.
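
As a rough illustration of why perturbed gradients carry information about $\overline{\mathrm{tr}}$, the following second-order expansion may help. It is only a sketch under assumptions not stated above: $f$ is taken to be smooth enough for the expansion to hold, and $g$ is assumed to be drawn uniformly from the unit sphere in $\mathbb{R}^d$ with a perturbation radius $\rho>0$ (both symbols are introduced here for illustration).

$$\nabla f(x+\rho g) = \nabla f(x) + \rho\,\nabla^2 f(x)\,g + \frac{\rho^2}{2}\,\nabla^3 f(x)[g,g] + \mathcal{O}(\rho^3).$$

Since $\mathbb{E}[g]=0$ and $\mathbb{E}[g g^\top]=\frac{1}{d}I_d$, the linear term vanishes in expectation and $\mathbb{E}\big[\nabla^3 f(x)[g,g]\big]=\frac{1}{d}\nabla\,\mathrm{tr}(\nabla^2 f(x))=\nabla\,\overline{\mathrm{tr}}(x)$, so that

$$\mathbb{E}_g\big[\nabla f(x+\rho g)\big] = \nabla f(x) + \frac{\rho^2}{2}\,\nabla\,\overline{\mathrm{tr}}(x) + \mathcal{O}(\rho^3).$$

Under these assumptions, the average perturbed gradient differs from the plain gradient by a term proportional to the gradient of the normalized Hessian trace.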

Abstract

Modern machine learning applications have witnessed the remarkable success of optimization algorithms that are designed to find flat minima. Motivated by this design choice, we undertake a formal study that (i) formulates the notion of flat minima, and (ii) studies the complexity of finding them. Specifically, we adopt the trace of the Hessian of the cost function as a measure of flatness, and use it to formally define the notion of approximate flat minima. Under this notion, we then analyze algorithms that find approximate flat minima efficiently. For general cost functions, we discuss a gradient-based algorithm that finds an approximate flat local minimum efficiently. The main component of the algorithm is to use gradients computed from randomly perturbed iterates to estimate a direction that leads to flatter minima. For the setting where the cost function is an empirical risk over training data, we present a faster algorithm that is inspired by a recently proposed practical algorithm called sharpness-aware minimization, supporting its success in practice.
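
The idea of estimating a flattening direction from gradients at randomly perturbed iterates can be checked numerically on a toy problem. The sketch below is not the paper's algorithm: the cost $f(x)=\frac{1}{4}\sum_i x_i^4$, the radius `rho`, the sample count, and the use of antithetic pairs $\pm g$ (a variance-reduction choice made here) are all assumptions for illustration. It only verifies that averaged perturbed gradients, rescaled by $2d/\rho^2$, approximate the gradient of the Hessian trace.

```python
import numpy as np

def grad_f(x):
    # Gradient of the toy cost f(x) = sum(x**4) / 4 (illustrative choice).
    return x ** 3

def trace_grad_exact(x):
    # For this toy cost, tr(Hessian) = 3 * sum(x**2), so its gradient is 6 * x.
    return 6.0 * x

def estimate_trace_grad(x, rho, num_samples, rng):
    """Estimate grad tr(Hessian f)(x) using only gradient evaluations."""
    d = x.size
    # Normalized Gaussian vectors are uniform on the unit sphere.
    g = rng.standard_normal((num_samples, d))
    g /= np.linalg.norm(g, axis=1, keepdims=True)
    # Antithetic pairs (+g, -g) cancel the first-order term rho * H g.
    avg = 0.5 * (grad_f(x + rho * g) + grad_f(x - rho * g)).mean(axis=0)
    # The remaining second-order term is (rho**2 / (2 * d)) * grad tr(H).
    return (2.0 * d / rho ** 2) * (avg - grad_f(x))

rng = np.random.default_rng(0)
x = np.array([1.0, -0.5, 0.25, 0.8, -1.2])
estimate = estimate_trace_grad(x, rho=0.1, num_samples=100_000, rng=rng)
print("estimated:", estimate)
print("exact:    ", trace_grad_exact(x))
```

Moving against this estimated direction would reduce the Hessian trace of the toy cost. The Monte Carlo estimate is noisy, so the two printed vectors agree only approximately.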

Paper Structure

This paper contains 33 sections, 14 theorems, 82 equations, 3 figures, 1 table, and 2 algorithms.

Key Result

Theorem 1

Let Assumption \ref{as:local} hold and let $f$ have $\beta$-Lipschitz gradients. Let the target accuracy $\epsilon>0$ be chosen sufficiently small, and let $\delta\in(0,1)$. Suppose that ${\bm x}_0$ is $\zeta$-close to $\mathcal{X}^\star$. Then, the randomly smoothed perturbation algorithm (Algorithm \ref{algo:RS}) with parameters $\eta = \ldots$

Figures (3)

  • Figure 1: Figure from Liu et al. (2023). They pretrain language models on a probabilistic context-free grammar with different optimization methods and compare their downstream accuracy. As the plot shows, the trace of the Hessian is a better indicator of performance than the pretraining loss itself.
  • Figure 2: Figure from Damian et al. (2021). For training ResNet-18 on CIFAR-10, they measure the trace of the Hessian across the iterates of SGD with label noise and observe an inspiring relation between $\mathrm{tr}(\nabla^2 f({\bm x}_t))$ and prediction performance.
  • Figure 3: (Left) Comparison between Randomly Smoothed Perturbation ("RS") and Sharpness-Aware Perturbation ("SA"). (Right) Comparison of SA with different batch sizes. We do observe that the trace of the Hessian decreases monotonically along the algorithm iterates, similarly to Damian et al. (2021) (see also Figure 2). We present the test accuracy instead of the trace of the Hessian, as it is of more practical value.

Theorems & Definitions (29)

  • Remark 1: Other notions of flatness?
  • Example 1
  • Definition 1: Limit point under gradient flow
  • Definition 2: Flat local minima
  • Definition 3: $(\epsilon,\epsilon')$-flat local minima
  • Definition 4: Projecting-out operator
  • Theorem 1
  • Lemma 1
  • Lemma 2
  • Lemma 3
  • ...and 19 more