High dimensional analysis reveals conservative sharpening and a stochastic edge of stability

Atish Agarwala, Jeffrey Pennington

TL;DR

This work analyzes stochastic gradient descent dynamics through the lens of Hessian-spectrum stability, introducing a stochastic edge of stability (S-EOS) governed by the noise kernel norm $\mathcal{K}$. By combining a high-dimensional linearized model with a second-moment analysis, it derives conditions for both the deterministic EOS ($\eta\lambda_{\max}<2$) and the S-EOS (threshold near $\mathcal{K}=1$), and explains conservative sharpening, where SGD noise preferentially dampens large eigenmodes. The paper combines theoretical results with extensive experiments on MNIST, CIFAR-10, and ImageNet, showing that $\mathcal{K}$ is a robust predictor of training outcomes across architectures and loss types, and that controlling SGD noise can improve optimization efficiency. These insights suggest practical ways to adapt learning rates and batch sizes by tracking $\mathcal{K}$ and the NTK spectrum, potentially leading to more stable and faster training of large-scale models.
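The deterministic EOS condition $\eta\lambda_{\max}<2$ can be checked on a toy quadratic. The sketch below is a hypothetical illustration (the function name `gd_on_quadratic` and the eigenvalues are made up, not taken from the paper): in the Hessian eigenbasis, full-batch gradient descent multiplies each coordinate by $1-\eta\lambda_k$ per step, so the loss decays exactly when $|1-\eta\lambda_k|<1$ for every $k$, i.e. when $\eta\lambda_{\max}<2$ for positive eigenvalues.

```python
def gd_on_quadratic(eigenvalues, eta, steps=200):
    """Full-batch GD on L(x) = 0.5 * sum_k lam_k * x_k**2, starting from x_k = 1.

    In the Hessian eigenbasis each coordinate evolves independently as
    x_k <- (1 - eta * lam_k) * x_k, so it shrinks iff |1 - eta * lam_k| < 1.
    """
    x = [1.0] * len(eigenvalues)
    for _ in range(steps):
        x = [(1.0 - eta * lam) * xk for lam, xk in zip(eigenvalues, x)]
    return 0.5 * sum(lam * xk * xk for lam, xk in zip(eigenvalues, x))


eigs = [1.0, 4.0]  # hypothetical spectrum: lambda_max = 4, so the EOS threshold is eta = 0.5
stable = gd_on_quadratic(eigs, eta=0.45)    # eta * lambda_max = 1.8 < 2
unstable = gd_on_quadratic(eigs, eta=0.55)  # eta * lambda_max = 2.2 > 2
print(f"eta*lmax=1.8 -> loss {stable:.3e}; eta*lmax=2.2 -> loss {unstable:.3e}")
```

Raising `eta` just past $2/\lambda_{\max}$ makes only the top mode grow while the other mode still contracts, which is why the full-batch threshold depends on the largest Hessian eigenvalue alone.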

Abstract

Recent empirical and theoretical work has shown that the dynamics of the large eigenvalues of the training loss Hessian have some remarkably robust features across models and datasets in the full batch regime. There is often an early period of progressive sharpening where the large eigenvalues increase, followed by stabilization at a predictable value known as the edge of stability. Previous work showed that in the stochastic setting, the eigenvalues increase more slowly, a phenomenon we call conservative sharpening. We provide a theoretical analysis of a simple high-dimensional model that shows the origin of this slowdown. We also show that there is an alternative stochastic edge of stability, which arises at small batch size and is sensitive to the trace of the Neural Tangent Kernel rather than the large Hessian eigenvalues. We conduct an experimental study that highlights the qualitative differences from the full batch phenomenology, and suggests that controlling the stochastic edge of stability can help optimization.
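The claim that small-batch stability depends on the NTK trace rather than on $\lambda_{\max}$ can be seen in a minimal sketch, assuming a noise-free linear regression with i.i.d. Gaussian features (similar in spirit to the Figure 1 setup, but with smaller, hypothetical sizes $P=30$, $D=20$). All helper names (`make_regression`, `hessian_lambda_max`, `train`) are hypothetical. For a linear model the NTK is $XX^\top$, so $\mathrm{tr}(\mathrm{NTK})/P$ is the mean squared feature norm; a single-sample step multiplies the error along $x_i$ by $1-\eta\|x_i\|^2$. This sketch does not compute the paper's $\mathcal{K}$; it only contrasts the two scales.

```python
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def make_regression(P=30, D=20, seed=0):
    """Noise-free linear regression y = X theta_star with i.i.d. N(0, 1) features."""
    rng = random.Random(seed)
    X = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(P)]
    theta_star = [rng.gauss(0.0, 1.0) for _ in range(D)]
    y = [dot(row, theta_star) for row in X]
    return X, y


def loss(X, y, theta):
    # L(theta) = (1 / 2P) * sum_i (x_i . theta - y_i)^2
    return sum((dot(row, theta) - yi) ** 2 for row, yi in zip(X, y)) / (2 * len(X))


def hessian_lambda_max(X, iters=200):
    """Power iteration for the top eigenvalue of the Hessian H = X^T X / P."""
    P, D = len(X), len(X[0])
    v = [1.0 / D ** 0.5] * D
    lam = 0.0
    for _ in range(iters):
        Xv = [dot(row, v) for row in X]
        Hv = [sum(X[i][j] * Xv[i] for i in range(P)) / P for j in range(D)]
        lam = dot(Hv, Hv) ** 0.5  # ||H v|| for a unit vector v
        v = [h / lam for h in Hv]
    return lam


def train(X, y, eta, batch_size, steps=1000, seed=1, blowup=1e10):
    """Minibatch SGD from theta = 0; returns the final loss, or inf once loss > blowup."""
    rng = random.Random(seed)
    P, D = len(X), len(X[0])
    theta = [0.0] * D
    for _ in range(steps):
        # Full batch when batch_size >= P, else sample indices without replacement.
        batch = range(P) if batch_size >= P else rng.sample(range(P), batch_size)
        grad = [0.0] * D
        for i in batch:
            r = dot(X[i], theta) - y[i]
            for j in range(D):
                grad[j] += X[i][j] * r / len(batch)
        theta = [t - eta * g for t, g in zip(theta, grad)]
        if loss(X, y, theta) > blowup:
            return float("inf")
    return loss(X, y, theta)


X, y = make_regression()
P = len(X)
lam_max = hessian_lambda_max(X)
ntk_diag_mean = sum(dot(row, row) for row in X) / P  # tr(X X^T) / P
init = loss(X, y, [0.0] * len(X[0]))
print(f"full-batch EOS scale 2/lambda_max     ~ {2 / lam_max:.3f}")
print(f"B=1 scale 2/(tr(NTK)/P)               ~ {2 / ntk_diag_mean:.3f}")
results = {}
for eta, B in [(0.3, P), (0.3, 1), (0.02, 1)]:
    results[(eta, B)] = train(X, y, eta, B)
    print(f"eta={eta}, B={B}: final loss {results[(eta, B)]:.3e} (initial {init:.3e})")
```

Here $2/\lambda_{\max}$ is several times larger than $2/(\mathrm{tr}(\mathrm{NTK})/P)$, so $\eta=0.3$ converges with the full batch yet diverges at $B=1$, while a smaller $\eta$ restores convergence at $B=1$.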

Paper Structure (40 sections, 6 theorems, 130 equations, 16 figures)

Key Result

Theorem 2.1

If the diagonal of $\mathbf{S}$ is governed by the recursion in Equation (pvec_volt), then $\lim_{t\to\infty}{\rm E}_{\mathbf{P}}[\mathbf{z}_{t}\mathbf{z}_{t}^{\top}] = 0$ for any initialization $\mathbf{z}_{0}$ if and only if $||\mathbf{A}||_{op} <1$ and $\mathcal{K} < 1$, where $\mathcal{K}$ is the noise kernel norm, defined in terms of the PSD matrices $\mathbf{A}$ and $\mathbf{B}$ introduced in the paper. $\mathcal{K}$ is always non-negative.
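To see why a second-moment condition can be stricter than the deterministic one, consider a hypothetical one-parameter toy case ($D=1$, batch size $1$), which is not the paper's general setting and does not reproduce its $\mathcal{K}$. Each step samples a data point $i$ uniformly, independently of the past, and updates the error as $e_{t+1}=(1-\eta k_i)\,e_t$, where $k_i=x_i^2$ is that point's NTK entry; full-batch GD instead uses the mean $\mathrm{E}[k]$, which is the Hessian.

```latex
\begin{aligned}
\mathrm{E}[e_{t+1}^2] &= \mathrm{E}_i\big[(1-\eta k_i)^2\big]\,\mathrm{E}[e_t^2]
  = \big(1 - 2\eta\,\mathrm{E}[k] + \eta^2\,\mathrm{E}[k^2]\big)\,\mathrm{E}[e_t^2],\\
\mathrm{E}[e_t^2] \to 0 \;&\Longleftrightarrow\; \eta < \frac{2\,\mathrm{E}[k]}{\mathrm{E}[k^2]}
  \;\le\; \frac{2}{\mathrm{E}[k]} \quad \text{(full-batch threshold)}.
\end{aligned}
```

The inequality follows from $\mathrm{E}[k^2]\ge \mathrm{E}[k]^2$, with equality only when all $k_i$ are equal, so in this toy case the spread of the per-sample kernel entries, and not only their mean, sets the mean-square stability threshold.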

Figures (16)

  • Figure 1: SGD trajectories for linear regression show divergence due to stochastic effects as $\eta$ is increased (left, $B = 5$, $D = 100$, $P = 120$, i.i.d. Gaussian $\mathbf{J}$). $\mathcal{K}$ interpolates from $0$ at small learning rate to a value of $1$ precisely when $\lambda_{max}[\mathbf{A}+\mathbf{B}] = 1$ (middle). Loss after $10^{4}$ steps diverges for $\mathcal{K} >1$ (right, plot saturated at $10^{1}$ for convenience).
  • Figure 2: Dynamics of the largest Hessian eigenvalue in a randomly initialized quadratic regression model at fixed learning rate for various batch sizes (averaged over $100$ seeds). Small batch size leads to increased initial sharpening but faster saturation (left, $V(\sigma) = 1$). Batch size differences are amplified when $\mathbf{Q}$ is more heavily weighted in larger eigenmodes (right, $V(\sigma) = \sigma$).
  • Figure 3: Dynamics of loss (left) and noise kernel norm $\mathcal{K}$ (right) for a fully connected network (FCN) trained on MNIST at various learning rates with batch size $1$. For small learning rates, loss decrease is slow and the kernel norm is well below $1$. For intermediate learning rates, $\mathcal{K}$ initially exceeds the critical value of $1$, then decreases and stabilizes below $1$, and loss decreases quickly. For larger learning rates, $\mathcal{K}$ stays above $1$ for a long period and loss decreases slowly.
  • Figure 4: $\lambda_{max}$ at convergence in the MNIST experiment. Left: for large $B$, final values of $\lambda_{max}$ are similar for the same $\eta$, especially when the dynamics reaches the EOS value $2/\eta$ (black dashed line); for small $B$, $\eta$ is not predictive of $\lambda_{max}$ and the EOS is not reached. Right: for small $B$ and small $\eta/B$, final values of $\lambda_{max}$ are similar for equal $\eta/B$.
  • Figure 5: Final noise kernel norm $\mathcal{K}_{f}$ is well predicted by $\eta/B$ for fixed-epoch training, and attains a value near $1$ over a large range of learning rates (left). Final loss is poor for $\mathcal{K}_{f}\ll 1$ (conservative steps) but also for $\mathcal{K}_{f}$ too close to $1$ (aggressive steps) (middle). $\lambda_{max}$ is not a good predictor of training loss (right).

Theorems & Definitions (11)

  • Theorem 2.1
  • Theorem 3.1
  • Lemma A.1
  • proof
  • Lemma A.2
  • proof
  • Lemma A.3
  • proof
  • Lemma A.4
  • proof