Self-Stabilization: The Implicit Bias of Gradient Descent at the Edge of Stability

Alex Damian, Eshaan Nichani, Jason D. Lee

TL;DR

This paper uncovers a generic self-stabilization mechanism in gradient descent at the edge of stability (EOS), where the top Hessian direction becomes unstable once the sharpness reaches the break-even value $S(\theta)=2/\eta$. Using a cubic Taylor expansion, the authors show that the term $\nabla^3 L(\theta)(u,u) \approx \nabla S(\theta)$, where $u$ is the top Hessian eigenvector, creates a negative feedback that curtails sharpness growth and couples the GD trajectory to a projected gradient descent (PGD) trajectory under the constraint $S(\theta)\le 2/\eta$. They derive precise predicted dynamics for a reduced 2D system in $(x,y)$, prove a Coupling Theorem showing that GD tracks the constrained PGD trajectory up to higher-order errors, and validate these predictions empirically across standard architectures and datasets. The work also discusses generalized dynamics when the loss along the top direction is non-quadratic, connects the analysis to practical training considerations (large learning rates, SAM, warmup, weight decay), and outlines how the theory might be extended to multiple unstable eigenvalues and stochastic optimization. Overall, the results give a concrete mechanism for gradient descent's implicit bias toward stability and explain why the loss in the EOS regime is non-monotonic yet decreases overall.
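Both the sharpness $S(\theta)$ and the top direction $u$ can be estimated without forming the Hessian. The sketch below uses power iteration on finite-difference Hessian-vector products; the loss `grad_L`, its coefficients, and the helper names `hvp` and `sharpness` are hypothetical illustrations, not code from the paper. Power iteration assumes the top eigenvalue is also the largest in magnitude.

```python
import numpy as np

A = np.array([120.0, 30.0, 5.0])  # hypothetical curvature coefficients

def grad_L(theta):
    """Gradient of the hypothetical loss L(theta) = 0.5 * sum(A_i theta_i^2) + theta_0^3 / 3."""
    g = A * theta
    g[0] += theta[0] ** 2  # the cubic term makes the sharpness depend on theta
    return g

def hvp(grad_fn, theta, v, eps=1e-5):
    """Central finite-difference Hessian-vector product: H v ~ (g(theta+eps v) - g(theta-eps v)) / (2 eps)."""
    return (grad_fn(theta + eps * v) - grad_fn(theta - eps * v)) / (2 * eps)

def sharpness(grad_fn, theta, iters=200, seed=0):
    """Estimate the top Hessian eigenvalue S(theta) and eigenvector u by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(theta.shape)
    u /= np.linalg.norm(u)
    for _ in range(iters):
        w = hvp(grad_fn, theta, u)
        u = w / np.linalg.norm(w)
    return u @ hvp(grad_fn, theta, u), u  # Rayleigh quotient and direction

eta = 0.02
theta = np.array([1.0, 0.5, -0.3])
S, u = sharpness(grad_L, theta)
print(S, S > 2 / eta)  # S is about 122 here, above the break-even value 2/eta = 100
```

For this toy loss the Hessian is $\mathrm{diag}(120 + 2\theta_0, 30, 5)$, so $S(\theta) = 120 + 2\theta_0$ and $\nabla S = (2, 0, 0)$, which equals $\nabla^3 L(\theta)(u,u)$ with $u = e_0$: moving along $-\nabla S$ lowers the sharpness.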

Abstract

Traditional analyses of gradient descent show that when the largest eigenvalue of the Hessian, also known as the sharpness $S(\theta)$, is bounded by $2/\eta$, training is "stable" and the training loss decreases monotonically. Recent works, however, have observed that this assumption does not hold when training modern neural networks with full-batch or large-batch gradient descent. Most recently, Cohen et al. (2021) observed two important phenomena. The first, dubbed progressive sharpening, is that the sharpness steadily increases throughout training until it reaches the instability cutoff $2/\eta$. The second, dubbed edge of stability, is that the sharpness hovers at $2/\eta$ for the remainder of training while the loss continues to decrease, albeit non-monotonically. We demonstrate that, far from being chaotic, the dynamics of gradient descent at the edge of stability can be captured by a cubic Taylor expansion: as the iterates diverge in the direction of the top eigenvector of the Hessian due to instability, the cubic term in the local Taylor expansion of the loss function causes the curvature to decrease until stability is restored. This property, which we call self-stabilization, is a general property of gradient descent and explains its behavior at the edge of stability. A key consequence of self-stabilization is that gradient descent at the edge of stability implicitly follows projected gradient descent (PGD) under the constraint $S(\theta) \le 2/\eta$. Our analysis provides precise predictions for the loss, sharpness, and deviation from the PGD trajectory throughout training, which we verify both empirically in a number of standard settings and theoretically under mild conditions. Our analysis uncovers the mechanism for gradient descent's implicit bias towards stability.
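Self-stabilization can be seen in a two-parameter setting. The sketch below runs GD on a hypothetical toy loss of our own choosing, $L(x,y) = -\alpha y + \tfrac12 (c + y) x^2$ (not the paper's toy model). The curvature along $x$ is $c + y$. The $-\alpha y$ term pushes $y$ up (progressive sharpening). The cubic term $\tfrac12 y x^2$ pushes $y$ back down once $x$ grows (self-stabilization). The values of `eta`, `alpha`, `c`, and `x0` are illustrative assumptions.

```python
import numpy as np

def simulate_toy_eos(eta=0.02, alpha=1.0, c=98.0, x0=0.1, steps=1000):
    """GD on the hypothetical loss L(x, y) = -alpha*y + 0.5*(c + y)*x**2.

    Returns the top Hessian eigenvalue (sharpness) recorded at every step.
    """
    x, y = x0, 0.0
    sharpness = []
    for _ in range(steps):
        hessian = np.array([[c + y, x], [x, 0.0]])
        sharpness.append(np.linalg.eigvalsh(hessian)[-1])  # eigvalsh sorts ascending
        grad_x = (c + y) * x            # unstable once c + y > 2/eta
        grad_y = -alpha + 0.5 * x**2    # sharpening (-alpha) vs. cubic feedback (x**2 / 2)
        x, y = x - eta * grad_x, y - eta * grad_y  # simultaneous GD update
    return np.array(sharpness)

eta = 0.02
S = simulate_toy_eos(eta=eta)
print(S.max(), S[500:].mean(), 2 / eta)
```

With these values the sharpness starts at 98, rises past $2/\eta = 100$, and then oscillates in a band of a few units around it instead of diverging. In this toy the band slowly widens over very long runs, because GD updates both coordinates from the same iterate.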

Paper Structure (55 sections, 22 theorems, 141 equations, 6 figures)

Key Result

Lemma 1

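Assume that $S(\theta) \le \ell$ for all $\theta$. If $\theta_{t+1} = \theta_t - \eta \nabla L(\theta_t)$, then $L(\theta_{t+1}) \le L(\theta_t) - \frac{\eta(2 - \eta\ell)}{2}\|\nabla L(\theta_t)\|^2$. In particular, the loss decreases monotonically whenever $\eta < 2/\ell$.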
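The descent-lemma bound $L(\theta_{t+1}) \le L(\theta_t) - \frac{\eta(2-\eta\ell)}{2}\|\nabla L(\theta_t)\|^2$ can be checked numerically on a quadratic whose Hessian eigenvalues lie in $[0, \ell]$, so that $S(\theta) = \ell$ everywhere. This is an illustrative sketch with a made-up $5 \times 5$ matrix, not code from the paper. It also shows that stepping past $\eta = 2/\ell$ makes the loss grow.

```python
import numpy as np

rng = np.random.default_rng(0)
d, ell = 5, 10.0
# Hypothetical quadratic L(theta) = 0.5 * theta^T A theta with Hessian eigenvalues in [0, ell]
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
A = Q @ np.diag(np.linspace(0.0, ell, d)) @ Q.T

def loss(theta):
    return 0.5 * theta @ A @ theta

def run_gd(eta, steps=50):
    """Run GD from a fixed start; record losses and the slack in the descent-lemma bound."""
    theta = Q @ np.ones(d)  # unit weight on every eigendirection of A
    losses, slack = [loss(theta)], []
    for _ in range(steps):
        g = A @ theta
        bound = loss(theta) - eta * (2 - eta * ell) / 2 * (g @ g)
        theta = theta - eta * g
        losses.append(loss(theta))
        slack.append(bound - losses[-1])  # >= 0 whenever the bound holds
    return np.array(losses), np.array(slack)

stable_losses, stable_slack = run_gd(eta=1.9 / ell)  # eta < 2/ell: monotone decrease
unstable_losses, _ = run_gd(eta=2.1 / ell)           # eta > 2/ell: top direction blows up
print(stable_losses[-1], unstable_losses[-1])
```

For a quadratic the slack equals $\tfrac{\eta^2}{2}\left(\ell\|g\|^2 - g^\top A g\right) \ge 0$ exactly, so any violation would come only from floating-point rounding.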

Figures (6)

  • Figure 1: Progressive Sharpening and Edge of Stability: We train an MLP on CIFAR10 with learning rate $\eta = 2/100$. It reaches instability after around $2200$ training steps, after which the sharpness hovers at $2/\eta = 100$ (horizontal dashed line).
  • Figure 2: The four stages of edge of stability (see the section on the four stages), demonstrated on a simple toy loss function (see the toy-model section).
  • Figure 3: The effect of $\mathbf{X(0)}$ (left): We plot the evolution of the ODE (eq:bean_ode) with $\alpha = \beta = 1$ for varying $X(0)$. Smaller values of $X(0)$ produce larger curves. The four stages of edge of stability (right): We show how the four stages of edge of stability, described in the four-stages section and in Figure 2, correspond to different parts of the curve generated by the same ODE.
  • Figure 4: We empirically demonstrate that the predicted dynamics (eq:predicted_x_y_only) track the true dynamics of gradient descent at the edge of stability. For each learning rate, the top row is a zoomed-in version of the bottom row that isolates one cycle, marked by the dashed rectangle in the bottom row. Reported sharpnesses are two-step averages for visual clarity. For additional details, see the experiments section and the experimental-details section.
  • Figure 5: Edge of stability with multiple unstable eigenvalues. Each vertical line is the time at which the corresponding eigenvalue of the same color becomes unstable.
  • ...and 1 more figure
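Figure 4 reports two-step averages of the sharpness because, at the edge of stability, the raw value tends to alternate from one step to the next. The sketch below assumes this means averaging consecutive iterates, $(S_t + S_{t+1})/2$. That is our reading, not the authors' plotting code, and the input sequence is made up for illustration.

```python
import numpy as np

def two_step_average(sharpness):
    """Average consecutive values: out[t] = (S[t] + S[t+1]) / 2."""
    s = np.asarray(sharpness, dtype=float)
    return 0.5 * (s[:-1] + s[1:])

# Made-up sequence that flips around 2/eta = 100 every step
raw = [99.0, 101.0, 99.5, 100.5]
smoothed = two_step_average(raw)
print(smoothed)  # values: 100.0, 100.25, 100.0
```

Averaging adjacent pairs cancels the period-2 flip and leaves the slower drift in the sharpness visible.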

Theorems & Definitions (48)

  • Definition 1
  • Lemma 1: Descent Lemma
  • Lemma 2: Self-Stabilization Property
  • Definition 2: Progressive Sharpening Coefficient
  • Lemma 3
  • proof
  • Corollary 1
  • Definition 3: Taylor Expansion Quantities at $\theta^\dagger_t$
  • Definition 4
  • Definition 5
  • ...and 38 more