Edge of Stochastic Stability: Revisiting the Edge of Stability for SGD
Arseniy Andreyev, Pierfrancesco Beneventano
TL;DR
The paper explains why mini-batch SGD traverses neural loss landscapes in a regime that differs from the full-batch Edge of Stability (EoS). It introduces Batch Sharpness, the mini-batch curvature measured along the direction of each stochastic gradient, which saturates near $2/\eta$; this defines the Edge of Stochastic Stability (EoSS). The paper proves that when Batch Sharpness exceeds $(2+\epsilon)/\eta$ on the quadratic approximation, SGD undergoes a catapult, i.e., it diverges on that local model. Empirically, across vision tasks trained with MSE loss, Batch Sharpness progressively sharpens and then stabilizes at $2/\eta$ regardless of batch size, while $\lambda_{\max}$ plateaus at a lower level that depends on the training history and the batch size; this helps explain why smaller batches reach flatter minima. The work further argues that classical optimization and SDE-based models fail to capture EoSS because they discard mini-batch geometry and higher-order curvature statistics, and it emphasizes the need to model batch-dependent Hessians and their alignments. Overall, EoSS unifies observed phenomena (catapults, progressive sharpening, batch-size effects) under a single stability criterion driven by mini-batch geometry, with significant implications for optimization theory and for the modeling of SGD trajectories.
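Why the threshold is $2/\eta$, and why the relevant curvature is the one along the stochastic gradient rather than $\lambda_{\max}$, can be seen from a standard one-step Taylor calculation. This is only an illustration under the assumption that the mini-batch loss $L_B$ is exactly quadratic around $\theta$ (so the expansion has no remainder); it is not the paper's proof, which concerns the averaged quantity and the $(2+\epsilon)/\eta$ condition.

```latex
% Quadratic model of the mini-batch loss L_B around \theta,
% with g_B = \nabla L_B(\theta) and H_B = \nabla^2 L_B(\theta).
L_B(\theta - \eta\, g_B) = L_B(\theta) - \eta\, \lVert g_B \rVert^2 + \frac{\eta^2}{2}\, g_B^\top H_B\, g_B

% The SGD step raises L_B exactly when the quadratic term outweighs the linear one:
L_B(\theta - \eta\, g_B) > L_B(\theta)
\iff \frac{g_B^\top H_B\, g_B}{\lVert g_B \rVert^2} > \frac{2}{\eta}
\qquad (g_B \neq 0)
```

The left-hand ratio is the directional curvature of $H_B$ along its own gradient; only the top eigenvalue of $H_B$ (and not of the full-batch Hessian) bounds it from above.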
Abstract
Recent findings by Cohen et al., 2021, demonstrate that when neural networks are trained with full-batch gradient descent at step size $\eta$, the largest eigenvalue $\lambda_{\max}$ of the full-batch Hessian consistently stabilizes around $2/\eta$. These results have significant implications for convergence and generalization. This, however, does not hold for mini-batch optimization algorithms, which limits the broader applicability of these findings. We show that mini-batch Stochastic Gradient Descent (SGD) trains in a different regime, which we term the Edge of Stochastic Stability (EoSS). In this regime, what stabilizes at $2/\eta$ is Batch Sharpness: the expected directional curvature of mini-batch Hessians along their corresponding stochastic gradients. As a consequence, $\lambda_{\max}$, which is generally smaller than Batch Sharpness, is suppressed, consistent with the long-standing empirical observation that smaller batches and larger step sizes favor flatter minima. We further discuss implications for the mathematical modeling of SGD trajectories.
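To make the definition of Batch Sharpness concrete, here is a minimal sketch that estimates it by Monte Carlo and compares it with $\lambda_{\max}$ of the full-batch Hessian. It assumes a toy linear least-squares model with MSE loss, so each mini-batch Hessian is $H_B = \frac{1}{b}\sum_{i \in B} x_i x_i^\top$, and it reads "expected directional curvature" as the ratio of averages $\mathbb{E}_B[g_B^\top H_B g_B] / \mathbb{E}_B[\lVert g_B \rVert^2]$; that reading, the sampling scheme (uniform, without replacement), and all function names (`batch_sharpness`, `lambda_max`, etc.) are hypothetical choices for illustration, not the authors' code.

```python
import random

# Hypothetical toy setting: linear least squares with MSE loss
# L_B(w) = (1 / (2b)) * sum_{i in B} (x_i . w - y_i)^2, whose mini-batch
# Hessian is H_B = (1 / b) * sum_{i in B} x_i x_i^T (independent of w).


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def batch_grad(X, y, w, idx):
    """Mini-batch gradient g_B = (1/b) * sum_i (x_i . w - y_i) x_i."""
    g = [0.0] * len(w)
    for i in idx:
        r = dot(X[i], w) - y[i]
        for j in range(len(w)):
            g[j] += r * X[i][j] / len(idx)
    return g


def curvature_along(X, idx, v):
    """v^T H_B v = (1/b) * sum_i (x_i . v)^2, computed without forming H_B."""
    return sum(dot(X[i], v) ** 2 for i in idx) / len(idx)


def batch_sharpness(X, y, w, batch_size, num_batches=200, seed=0):
    """Monte Carlo estimate of E[g_B^T H_B g_B] / E[||g_B||^2].

    Assumption: batches are drawn uniformly without replacement, and the
    expectation is taken as a ratio of averages over sampled batches.
    """
    rng = random.Random(seed)
    num = den = 0.0
    for _ in range(num_batches):
        idx = rng.sample(range(len(X)), batch_size)
        g = batch_grad(X, y, w, idx)
        num += curvature_along(X, idx, g)
        den += dot(g, g)
    return num / den


def lambda_max(X, iters=500):
    """Largest eigenvalue of the full-batch Hessian H = (1/n) sum_i x_i x_i^T,
    estimated by power iteration with a Rayleigh-quotient readout."""
    v = [1.0] * len(X[0])
    lam = 0.0
    for _ in range(iters):
        hv = [sum(dot(x, v) * x[j] for x in X) / len(X) for j in range(len(v))]
        lam = dot(v, hv) / dot(v, v)
        norm = dot(hv, hv) ** 0.5
        v = [c / norm for c in hv]
    return lam


if __name__ == "__main__":
    # Two orthogonal unit inputs, targets 1, evaluated at w = 0.
    X = [[1.0, 0.0], [0.0, 1.0]]
    y = [1.0, 1.0]
    w = [0.0, 0.0]
    print("lambda_max:", lambda_max(X))
    for b in (1, 2):
        print("batch size", b, "Batch Sharpness:", batch_sharpness(X, y, w, b))
```

On this two-point dataset, single-sample batches give a Batch Sharpness of 1.0, while the full-batch $\lambda_{\max}$ is about 0.5 and the full batch ($b = 2$) recovers 0.5. This is a small instance of $\lambda_{\max}$ sitting below Batch Sharpness: each single-sample Hessian $x_i x_i^\top$ is sharper along its own gradient than the averaged Hessian is in any direction.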
