How to escape sharp minima with random perturbations
Kwangjun Ahn, Ali Jadbabaie, Suvrit Sra
TL;DR
The paper formalizes flat minima by adopting the normalized Hessian trace $\overline{\mathrm{tr}}(x)=\frac{\mathrm{tr}(\nabla^2 f(x))}{d}$ as the flatness measure and by defining approximate flat minima through a gradient-flow-based map $\Phi$. It then presents two gradient-based algorithms that find approximate flat minima efficiently: (i) Randomly Smoothed Perturbation (RS), which uses gradients at randomly perturbed iterates to estimate a descent direction for $\overline{\mathrm{tr}}$ and reaches an $(\epsilon,\sqrt{\epsilon})$-flat minimum in $T=\mathcal{O}(\epsilon^{-3})$ iterations (up to constants, with probability $1-\mathcal{O}(\delta)$); and (ii) Sharpness-Aware Perturbation (abbreviated SAM here), a variant for empirical-risk objectives inspired by sharpness-aware minimization, which reaches the same flatness target at a faster rate in high dimensions, namely $T=\mathcal{O}\left( d^{-1}\epsilon^{-2} \max\{1,\frac{1}{d^3\epsilon}\}\delta^{-4}\right)$ iterations, followed by gradient descent to finish. The analysis shows that random perturbations let gradients alone reveal higher-order curvature information, which drives the trace of the Hessian down, and that the stochastic gradients used in SAM accelerate this descent further. Experiments on CIFAR-10 with ResNet-18 corroborate the theory: SAM escapes sharp minima more effectively than RS, even with small perturbation radii and across batch sizes. Overall, the work provides formal definitions, complexity bounds, and empirically validated algorithms that connect gradient-based optimization to the search for flat minima, with practical implications for generalization in deep learning.
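To make the flatness measure concrete, the following is a minimal sketch that evaluates $\overline{\mathrm{tr}}(x)$ with central finite differences. The toy loss $f(x,y)=\tfrac12(xy-1)^2$, the function names, and the step size `h` are illustrative assumptions and do not come from the paper. Every point with $xy=1$ is a global minimum of this loss, and its Hessian trace works out to $x^2+y^2$, so the minima differ in flatness.

```python
def toy_loss(p):
    # Hypothetical two-parameter loss; its global minima form the curve x * y = 1.
    x, y = p
    return 0.5 * (x * y - 1.0) ** 2


def normalized_hessian_trace(f, p, h=1e-3):
    # Approximates tr(Hessian of f at p) / d by summing central second
    # differences along each coordinate axis (the diagonal Hessian entries).
    d = len(p)
    f0 = f(p)
    total = 0.0
    for i in range(d):
        plus = list(p)
        minus = list(p)
        plus[i] += h
        minus[i] -= h
        total += (f(plus) - 2.0 * f0 + f(minus)) / (h * h)
    return total / d


# Two global minima of the toy loss with different flatness.
flat_value = normalized_hessian_trace(toy_loss, [1.0, 1.0])
sharp_value = normalized_hessian_trace(toy_loss, [2.0, 0.5])
print(flat_value, sharp_value)
```

Both points have zero loss, yet the measure is smaller at $(1,1)$ than at $(2,0.5)$. That is the distinction the flatness definition is meant to capture: among minima, prefer those with a smaller Hessian trace.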
Abstract
Modern machine learning applications have witnessed the remarkable success of optimization algorithms that are designed to find flat minima. Motivated by this design choice, we undertake a formal study that (i) formulates the notion of flat minima, and (ii) studies the complexity of finding them. Specifically, we adopt the trace of the Hessian of the cost function as a measure of flatness, and use it to formally define the notion of approximate flat minima. Under this notion, we then analyze algorithms that find approximate flat minima efficiently. For general cost functions, we discuss a gradient-based algorithm that finds an approximate flat local minimum efficiently. The main component of the algorithm is to use gradients computed from randomly perturbed iterates to estimate a direction that leads to flatter minima. For the setting where the cost function is an empirical risk over training data, we present a faster algorithm that is inspired by a recently proposed practical algorithm called sharpness-aware minimization, supporting its success in practice.
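The idea of using gradients at randomly perturbed iterates can be illustrated with a standard Taylor-expansion identity: for $g$ uniform on the unit sphere in $\mathbb{R}^d$, $\mathbb{E}[\nabla f(x+\rho g)] = \nabla f(x) + \frac{\rho^2}{2d}\nabla\,\mathrm{tr}(\nabla^2 f(x)) + \text{higher-order terms}$. The sketch below checks this numerically on a hypothetical toy loss $f(x,y)=\tfrac12(xy-1)^2$, for which $\mathrm{tr}(\nabla^2 f)=x^2+y^2$. It is not the paper's algorithm: the toy loss, the function names, and the antithetic pairing $\pm g$ (added here only to reduce Monte Carlo noise) are all assumptions made for illustration.

```python
import math
import random


def grad_toy_loss(x, y):
    # Gradient of the hypothetical loss f(x, y) = 0.5 * (x*y - 1)**2.
    r = x * y - 1.0
    return (r * y, r * x)


def averaged_perturbed_gradient(x, y, rho, n_pairs, seed=0):
    # Averages the gradient over random perturbations of radius rho.
    # Each direction g is drawn uniformly from the unit circle (d = 2) and
    # used together with -g (antithetic pair, an illustrative choice).
    rng = random.Random(seed)
    sum_x = 0.0
    sum_y = 0.0
    for _ in range(n_pairs):
        theta = rng.uniform(0.0, 2.0 * math.pi)
        gx, gy = math.cos(theta), math.sin(theta)
        for sign in (1.0, -1.0):
            dx, dy = grad_toy_loss(x + sign * rho * gx, y + sign * rho * gy)
            sum_x += dx
            sum_y += dy
    n = 2 * n_pairs
    return (sum_x / n, sum_y / n)


def predicted_direction(x, y, rho, d=2):
    # (rho^2 / (2d)) * grad tr(Hessian); for the toy loss the trace is
    # x^2 + y^2, so its gradient is (2x, 2y).
    scale = rho * rho / (2.0 * d)
    return (scale * 2.0 * x, scale * 2.0 * y)


# At the minimum (2, 0.5) the plain gradient is zero, so the averaged
# perturbed gradient isolates the trace-gradient term.
estimate = averaged_perturbed_gradient(2.0, 0.5, rho=0.1, n_pairs=50000)
prediction = predicted_direction(2.0, 0.5, rho=0.1)
print(estimate, prediction)
```

Although the plain gradient vanishes at the minimum, the averaged perturbed gradient is nonzero and aligned with $\nabla\,\mathrm{tr}(\nabla^2 f)$, so stepping against it lowers the trace. Turning this into a procedure that stays near the set of minima takes additional steps that this sketch does not attempt.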
