Product-Stability: Provable Convergence for Gradient Descent on the Edge of Stability

Eric Gan

Abstract

Empirically, modern deep learning training often occurs at the Edge of Stability (EoS), where the sharpness of the loss exceeds the threshold below which classical convergence analysis applies. Despite recent progress, existing theoretical explanations of EoS either rely on restrictive assumptions or focus on specific squared-loss-type objectives. In this work, we introduce and study a structural property of loss functions that we term product-stability. We show that for losses with product-stable minima, gradient descent applied to objectives of the form $(x,y) \mapsto l(xy)$ can provably converge to the local minimum even when training in the EoS regime. This framework substantially generalizes prior results and applies to a broad class of losses, including binary cross entropy. Using bifurcation diagrams, we characterize the resulting training dynamics, explain the emergence of stable oscillations, and precisely quantify the sharpness at convergence. Together, our results offer a principled explanation for stable EoS training for a wider class of loss functions.
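As an illustration of the setting, the sketch below runs gradient descent on a toy instance of the objective class $(x,y) \mapsto l(xy)$ with the illustrative quadratic choice $l(z) = (z-1)^2$ (an assumption for this example; the paper's results cover more general losses such as binary cross entropy). Starting near a sharp minimum whose sharpness exceeds the classical stability threshold $2/\eta$, the iterates escape and settle at a flatter minimum, in the spirit of the EoS dynamics described in the paper:

```python
import numpy as np

def gd_on_product_loss(x, y, lr, steps):
    """Run gradient descent on L(x, y) = l(x*y) with the toy loss l(z) = (z - 1)^2."""
    for _ in range(steps):
        r = x * y - 1.0                       # residual: l'(z) = 2*r at z = x*y
        gx, gy = 2.0 * r * y, 2.0 * r * x     # chain rule: dL/dx, dL/dy
        x, y = x - lr * gx, y - lr * gy
    return x, y

def sharpness(x, y):
    """Top Hessian eigenvalue at a minimum (x*y = 1): l''(z)*(x^2 + y^2) = 2*(x^2 + y^2)."""
    return 2.0 * (x ** 2 + y ** 2)

lr = 0.2                      # classical stability requires sharpness < 2/lr = 10
x0, y0 = 3.0, 0.34            # near a sharp minimum: sharpness(x0, y0) ≈ 18.2 > 10
x, y = gd_on_product_loss(x0, y0, lr, 500)

print(abs(x * y - 1.0))       # small: the iterates land on a minimum anyway
print(sharpness(x, y) < sharpness(x0, y0))  # the final minimum is flatter
```

The transient is oscillatory and the particular endpoint depends on the initialization, but the qualitative picture (escape from the sharp minimum, convergence to a flatter one below the threshold) matches the behavior the paper analyzes.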

Paper Structure

This paper contains 35 sections, 23 theorems, 99 equations, and 6 figures.

Key Result

Lemma 3.1

The trace of the Hessian of $\mathcal{L}$ is given by $\mathop{\mathrm{tr}}\nolimits(\nabla^2 \mathcal{L}) = l''(z)\, s$. If $l'(z) = 0$, then the sharpness of $\mathcal{L}$ is also given by $\lambda = \mathop{\mathrm{tr}}\nolimits(\nabla^2 \mathcal{L}) = l''(z)\, s$.
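To make the quantities concrete: for $\mathcal{L}(x,y) = l(xy)$ with $z = xy$, the chain rule gives $\nabla^2 \mathcal{L} = l''(z)\begin{pmatrix} y^2 & xy \\ xy & x^2 \end{pmatrix} + l'(z)\begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}$, so the trace is $l''(z)(x^2 + y^2)$, consistent with the lemma if $s$ denotes $x^2 + y^2$ (our reading of $s$ from this computation; the paper defines it precisely). A minimal finite-difference check under that reading, with the illustrative loss $l(z) = (z-1)^2$:

```python
import numpy as np

def l(z):                     # illustrative loss with a minimum at z = 1
    return (z - 1.0) ** 2

def L(p):                     # objective L(x, y) = l(x * y)
    x, y = p
    return l(x * y)

def numerical_hessian(f, p, h=1e-4):
    """Central finite-difference Hessian of f at p."""
    n = len(p)
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            pp = p.copy(); pp[i] += h; pp[j] += h
            pm = p.copy(); pm[i] += h; pm[j] -= h
            mp = p.copy(); mp[i] -= h; mp[j] += h
            mm = p.copy(); mm[i] -= h; mm[j] -= h
            H[i, j] = (f(pp) - f(pm) - f(mp) + f(mm)) / (4.0 * h * h)
    return H

p = np.array([2.0, 0.5])      # x * y = 1, so l'(z) = 0 here
H = numerical_hessian(L, p)
s = p[0] ** 2 + p[1] ** 2     # s = x^2 + y^2 = 4.25

print(np.trace(H))            # ≈ l''(1) * s = 2 * 4.25 = 8.5
print(np.linalg.eigvalsh(H)[-1])  # sharpness also ≈ 8.5, as the lemma states
```

Since $l'(z) = 0$ at the minimum, the Hessian is the rank-one matrix $l''(z)\,(y, x)(y, x)^\top$, so its only nonzero eigenvalue equals its trace, which is why sharpness and trace coincide there.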

Figures (6)

  • Figure 1: EoS Dynamics in the xy Plane. Iterates start on the right near a high-sharpness minimum. They quickly diverge away from the sharp minimum before drifting towards a flatter minimum on the left.
  • Figure 3: EoS training dynamics for $l = \text{MLSq}_{1,2}$ (\ref{eq:mlsq_definition}) when started very close to the EoS threshold. The iterates do not converge to the final sharpness predicted by \ref{thm:final_sharpness}, showing that the $\delta$ gap is required.
  • Figure 4: End-of-training dynamics for the runs in \ref{fig:eos_training_dynamics}. One can observe Phase III of the training dynamics, where the iterates converge to $z_*$. The limiting sharpness is just below the EoS threshold and very close to the value predicted by \ref{thm:final_sharpness}.
  • Figure 5: Training dynamics of a fully-connected tanh network on CIFAR-10. a) shows the training loss, which consistently decreases over long timescales. The loss also oscillates while in the EoS regime; see \ref{fig:oscillation} for a zoomed-in view. b) shows the sharpness, with the dotted lines representing the EoS threshold. Training enters the EoS regime, where the sharpness oscillates around the EoS threshold. c) shows the product-stability calculated using directional derivatives along the direction of maximal sharpness. The product-stability is positive, indicating stability.
  • Figure 6: Loss Oscillation in EoS Regime. Zoomed-in version of \ref{fig:a}, showing that the loss is oscillating in the EoS regime.
  • ...and 1 more figure

Theorems & Definitions (44)

  • Lemma 3.1
  • Definition 4.1
  • Theorem 4.2
  • Theorem 4.3
  • Lemma 5.1
  • Theorem 5.2
  • Lemma 6.1
  • Lemma 6.2
  • Lemma 6.4
  • Lemma 6.5
  • ...and 34 more