
Edge of Stochastic Stability: Revisiting the Edge of Stability for SGD

Arseniy Andreyev, Pierfrancesco Beneventano

TL;DR

The paper addresses why mini-batch SGD traverses neural loss landscapes in a regime that differs from the full-batch Edge of Stability (EoS). It introduces Batch Sharpness, a direction-aware measure of mini-batch curvature that saturates near $2/\eta$, uses it to define the Edge of Stochastic Stability (EoSS), and proves that when Batch Sharpness exceeds $(2+\epsilon)/\eta$ on the quadratic approximation, SGD experiences a catapult, i.e., divergence on that local model. Empirically, across vision tasks trained with MSE loss, Batch Sharpness rises through progressive sharpening and then stabilizes at $2/\eta$ regardless of batch size, while $\lambda_{\max}$ plateaus at a lower level that depends on the training history and the batch size; this accounts for the advantage of smaller batches in reaching flatter minima. The work further argues that classical optimization and SDE-based models fail to capture EoSS because they discard mini-batch geometry and higher-order curvature statistics, and it emphasizes the need to model batch-dependent Hessians and their alignments. Overall, EoSS unifies observed phenomena (catapults, progressive sharpening, batch-size effects) under a stability criterion driven by mini-batch geometry, with significant implications for optimization theory and the modeling of SGD trajectories.
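A one-step calculation on an exactly quadratic mini-batch loss shows where the $2/\eta$ threshold comes from. This is a heuristic sketch, not the paper's proof of Theorem 1. It assumes that $L_B$ is quadratic with Hessian $H_B$, writes $g_B = \nabla L_B(\theta)$ for the mini-batch gradient, and takes Batch Sharpness in the ratio-of-expectations form $\mathbb{E}_B[g_B^\top H_B g_B]/\mathbb{E}_B[\|g_B\|^2]$, which is an assumption about the exact form of Definition 3.

```latex
\begin{align*}
L_B(\theta - \eta g_B)
  &= L_B(\theta) - \eta \|g_B\|^2 + \tfrac{\eta^2}{2}\, g_B^\top H_B\, g_B \\
  &= L_B(\theta) - \eta \|g_B\|^2 \left(1 - \tfrac{\eta}{2}\,\frac{g_B^\top H_B\, g_B}{\|g_B\|^2}\right), \\
\mathbb{E}_B\!\left[L_B(\theta - \eta g_B) - L_B(\theta)\right]
  &= -\eta\, \mathbb{E}_B\|g_B\|^2 \left(1 - \tfrac{\eta}{2}\,\frac{\mathbb{E}_B[g_B^\top H_B\, g_B]}{\mathbb{E}_B\|g_B\|^2}\right).
\end{align*}
```

Under these assumptions, a step increases the loss of the batch it was computed on exactly when the directional curvature $g_B^\top H_B g_B/\|g_B\|^2$ exceeds $2/\eta$, and it does so in expectation exactly when Batch Sharpness exceeds $2/\eta$. The full-batch $\lambda_{\max}$ does not appear in this calculation.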

Abstract

Recent findings by Cohen et al., 2021, demonstrate that when neural networks are trained with full-batch gradient descent with step size $\eta$, the largest eigenvalue $\lambda_{\max}$ of the full-batch Hessian consistently stabilizes around $2/\eta$. These results have significant implications for convergence and generalization. This, however, is not the case for mini-batch optimization algorithms, which limits the broader applicability of these findings. We show that mini-batch Stochastic Gradient Descent (SGD) trains in a different regime, which we term Edge of Stochastic Stability (EoSS). In this regime, what stabilizes at $2/\eta$ is Batch Sharpness: the expected directional curvature of mini-batch Hessians along their corresponding stochastic gradients. As a consequence, $\lambda_{\max}$, which is generally smaller than Batch Sharpness, is suppressed, aligning with the long-standing empirical observation that smaller batches and larger step sizes favor flatter minima. We further discuss implications for mathematical modeling of SGD trajectories.
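To make the definition of Batch Sharpness concrete, the sketch below estimates it by Monte Carlo on a least-squares model, where every mini-batch loss is exactly quadratic, and compares it with the full-batch $\lambda_{\max}$. Several choices here are assumptions rather than details taken from the paper: the per-batch loss $L_B(\theta)=\|X_B\theta-y_B\|^2/(2b)$, uniform sampling of batches without replacement, and the ratio-of-expectations form $\mathbb{E}_B[g_B^\top H_B g_B]/\mathbb{E}_B[\|g_B\|^2]$. The helper names `batch_sharpness` and `full_batch_lambda_max` are hypothetical. For a neural network, the explicit product $H_B g_B$ would be replaced by an automatic-differentiation Hessian-vector product.

```python
import numpy as np


def batch_sharpness(X, y, theta, batch_size, num_batches=1000, seed=0):
    """Monte Carlo estimate of Batch Sharpness for least squares (hypothetical helper).

    Assumes L_B(theta) = ||X_B theta - y_B||^2 / (2 b), so that
    g_B = X_B^T (X_B theta - y_B) / b and H_B = X_B^T X_B / b.
    Returns E_B[g_B^T H_B g_B] / E_B[||g_B||^2] (ratio-of-expectations form, an assumption).
    """
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    num = den = 0.0
    for _ in range(num_batches):
        idx = rng.choice(n, size=batch_size, replace=False)  # uniform batch, no replacement
        Xb, yb = X[idx], y[idx]
        g = Xb.T @ (Xb @ theta - yb) / batch_size  # mini-batch gradient g_B
        Hg = Xb.T @ (Xb @ g) / batch_size          # Hessian-vector product H_B g_B
        num += g @ Hg                              # accumulates g_B^T H_B g_B
        den += g @ g                               # accumulates ||g_B||^2
    return num / den


def full_batch_lambda_max(X):
    """Largest eigenvalue of the full-batch Hessian H = X^T X / n."""
    H = X.T @ X / X.shape[0]
    return np.linalg.eigvalsh(H)[-1]


# Illustrative data: n = 500 Gaussian inputs in d = 50 dimensions, noiseless targets.
rng = np.random.default_rng(1)
n, d = 500, 50
X = rng.standard_normal((n, d))
theta_star = rng.standard_normal(d)
y = X @ theta_star
theta = np.zeros(d)  # evaluate curvature at an arbitrary point

lam_max = full_batch_lambda_max(X)
sharpness = {b: batch_sharpness(X, y, theta, b) for b in (1, 32, n)}
for b, s in sharpness.items():
    print(f"b={b:4d}  Batch Sharpness={s:7.2f}  full-batch lambda_max={lam_max:5.2f}")
```

In this toy setting, Batch Sharpness should shrink as $b$ grows. At $b=1$, each per-sample directional curvature equals $\|x_i\|^2\approx d$, far above $\lambda_{\max}$. At $b=n$, the ratio is a Rayleigh quotient of the full Hessian and therefore cannot exceed $\lambda_{\max}$. The exact values depend on the seed.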

Paper Structure (132 sections, 15 theorems, 175 equations, 36 figures)

Key Result

Lemma 1

Assume $h_0$ saturates $f$ at $\theta_t$. Under assumptions (a)–(c) above, any sufficiently small destabilizing perturbation of $h_0$ (e.g. $\eta\uparrow$ or $b\downarrow$) produces hyperparameters $h$ such that $f(\theta_t;h)>c(h)$. By validity of the criterion, the quadratic‑model trajectory from
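The following is a coarse numerical illustration of the "$b\downarrow$" direction on a quadratic (least-squares) model. At a step size where full-batch gradient descent is stable because $\lambda_{\max}<2/\eta$, shrinking the batch to $b=1$ destabilizes training. This happens because each single-sample step has directional curvature $\|x_i\|^2$, and for standard Gaussian inputs in $d=50$ dimensions this is about $50$, well above $2/\eta=20$. This is not the lemma's setting, which concerns small perturbations of hyperparameters that saturate the criterion. The data, the step size $\eta=0.1$, and the helper `train` are illustrative assumptions.

```python
import numpy as np


def train(X, y, eta, batch_size, steps, seed=0):
    """Mini-batch SGD on L(theta) = ||X theta - y||^2 / (2n) (hypothetical helper).

    Returns the full-batch loss after every step; batch_size = n gives full-batch GD.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    theta = np.zeros(d)
    losses = []
    for _ in range(steps):
        idx = rng.choice(n, size=batch_size, replace=False)
        Xb, yb = X[idx], y[idx]
        theta = theta - eta * Xb.T @ (Xb @ theta - yb) / batch_size  # SGD step
        losses.append(0.5 * np.mean((X @ theta - y) ** 2))          # full-batch loss
    return losses


rng = np.random.default_rng(0)
n, d, eta = 500, 50, 0.1  # 2 / eta = 20
X = rng.standard_normal((n, d))
y = X @ rng.standard_normal(d)  # noiseless targets, so the minimum loss is 0

# For Gaussian data, lambda_max is close to (1 + sqrt(d / n))^2, about 1.73, far below 2 / eta.
lam_max = np.linalg.eigvalsh(X.T @ X / n)[-1]
# ||x_i||^2 is the directional curvature of a b = 1 step taken on sample i.
per_sample_curv = np.sum(X ** 2, axis=1)

gd_losses = train(X, y, eta, batch_size=n, steps=100)
sgd_losses = train(X, y, eta, batch_size=1, steps=100)
print(f"lambda_max = {lam_max:.2f}, 2/eta = {2 / eta:.0f}, mean ||x_i||^2 = {per_sample_curv.mean():.1f}")
print(f"GD  loss: {gd_losses[0]:.3g} -> {gd_losses[-1]:.3g}")
print(f"SGD loss: {sgd_losses[0]:.3g} -> {sgd_losses[-1]:.3g}")
```

Here the GD loss is expected to shrink by orders of magnitude, while the $b=1$ SGD loss grows by orders of magnitude. On the quadratic model this is outright divergence, which the TL;DR identifies with a catapult. In this toy setting, moderate batch sizes such as $b=32$ keep each step's curvature, bounded by $\lambda_{\max}(H_B)\approx 5$, below $2/\eta$.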

Figures (36)

  • Figure 1: SGD at EoSS under different step sizes and batch sizes. MLP on an 8k subset of CIFAR-10 with step size $\eta$. Batch Sharpness stabilizes at the $2/\eta$ threshold across varying batch sizes and step sizes.
  • Figure 2: SGD on CIFAR-10 with $\eta = 1/400$. The full-batch Hessian’s $\lambda_{\max}$ plateaus below $2/\eta$. Smaller batch sizes lead to lower plateau values.
  • Figure 3: Catapults at EoSS. During EoSS, randomness in batch sampling can cause catapults, which are followed by renewed progressive sharpening (PS) and a return to EoSS. Notation follows Fig. [fig:eoss].
  • Figure 4: Comparing different sharpness measures. Red: step sharpness, i.e., the sharpness observed on the current step's mini-batch (essentially Batch Sharpness without the expectation); Green: Batch Sharpness (Definition 3); Blue: full-batch $\lambda_{\max}$. Top row: MLP (2 hidden layers of width 512); middle: 5-layer CNN; bottom: ResNet-14; all trained on an 8k subset of CIFAR-10.
  • Figure 5: (1) Type-1 oscillations persist throughout training (see Proposition [prop:lee-jang]; GNI $\approx 2/\eta$). (2) However, GNI being $2/\eta$ does not govern Type-2 oscillations, which highlights the difference between the two types of oscillations. (3) Batch Sharpness is instead an indicator of Type-2 oscillations: catapults happen only when the shift in hyperparameters occurs after Batch Sharpness reaches $2/\eta$.
  • ...and 31 more figures

Theorems & Definitions (30)

  • Definition 1: Instability criterion
  • Definition 2: Catapults on the quadratic model
  • Lemma 1: Tight instability criterion $\Rightarrow$ catapult under perturbation
  • Definition 3: Batch Sharpness
  • Definition 4: Oscillatory time window
  • Definition 5
  • Lemma 2: GNI is a certificate of oscillations
  • Proposition 1
  • Theorem 1: Batch Sharpness is an instability criterion
  • Proposition 2: Loss increment and GNI
  • ...and 20 more