Universal Sharpness Dynamics in Neural Network Training: Fixed Point Analysis, Edge of Stability, and Route to Chaos

Dayal Singh Kalra, Tianyu He, Maissam Barkeshli

TL;DR

The paper investigates robust sharpness dynamics during neural network training by analyzing a minimal two-layer linear UV model, which recapitulates early sharpness reduction, progressive sharpening, and edge of stability observed in real networks. Through fixed-point analysis in function space, it derives a critical learning rate $\eta_c$ and reveals an EoS attractor on a coupled $(\Delta f, \lambda)$ manifold, with a period-doubling route to chaos as $\eta$ increases. The work then validates these predictions in realistic architectures and datasets, showing that initialization and parameterization strongly shape sharpness trajectories and that the EoS phase diagram generalizes beyond the UV model. While the UV model provides a coherent explanatory framework for sharpness phenomena, it also has limitations in capturing nonmonotonic loss dynamics and long-range correlations in all real-world settings, motivating future work on extending the fixed-point approach to broader nonlinear regimes.

Abstract

In gradient descent dynamics of neural networks, the top eigenvalue of the loss Hessian (sharpness) displays a variety of robust phenomena throughout training. These include an early-time regime in which the sharpness may decrease (sharpness reduction), and later-time behavior such as progressive sharpening and edge of stability. We demonstrate that a simple $2$-layer linear network (UV model) trained on a single training example exhibits all of the essential sharpness phenomenology observed in real-world scenarios. By analyzing the structure of dynamical fixed points in function space and the vector field of function updates, we uncover the underlying mechanisms behind these sharpness trends. Our analysis reveals (i) the mechanism behind early sharpness reduction and progressive sharpening, (ii) the required conditions for edge of stability, (iii) the crucial role of initialization and parameterization, and (iv) a period-doubling route to chaos on the edge of stability manifold as the learning rate is increased. Finally, we demonstrate that various predictions from this simplified model generalize to real-world scenarios and discuss its limitations.
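
A minimal sketch of the single-example setup described in the abstract may help make the quantities concrete. It assumes the UV model has the form $f = \bm{v}^\top U \bm{x} / \sqrt{n_{\text{eff}}}$ with MSE loss $\tfrac{1}{2}\Delta f^2$, standard normal initialization, and $\lambda$ taken as the squared gradient norm of $f$ (which equals the top Hessian eigenvalue at zero loss). None of these choices is spelled out in this summary, and the function name `train_uv` and its defaults are illustrative only.

```python
import numpy as np

def train_uv(n=8, d=4, n_eff=8.0, y=2.0, eta=0.05, steps=500, seed=0):
    """Full-batch GD on f = v @ U @ x / sqrt(n_eff), loss 0.5 * (f - y)**2.

    Assumed model form and N(0, 1) initialization; not taken from the paper text.
    Returns an array of (delta_f, lam) pairs, one per step.
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(size=d)
    x /= np.linalg.norm(x)               # unit-norm input, ||x|| = 1
    U = rng.normal(size=(n, d))          # first-layer weights
    v = rng.normal(size=n)               # second-layer weights
    scale = np.sqrt(n_eff)
    history = []
    for _ in range(steps):
        a = U @ x                                    # hidden pre-activations
        delta_f = v @ a / scale - y                  # residual, Delta f = f - y
        lam = ((v @ v) * (x @ x) + a @ a) / n_eff    # squared gradient norm of f
        history.append((delta_f, lam))
        grad_U = delta_f * np.outer(v, x) / scale    # dL/dU
        grad_v = delta_f * a / scale                 # dL/dv
        U = U - eta * grad_U
        v = v - eta * grad_v
    return np.array(history)

traj = train_uv()
delta_f, lam = traj[:, 0], traj[:, 1]
```

Under these assumptions every recorded point satisfies $\lambda \ge 2 \|\bm{x}\| |\Delta f + y| / \sqrt{n_{\text{eff}}}$, which is the complement of the forbidden region quoted in the Figure 2 caption, so plotting `lam` against `delta_f` gives a trajectory in the $(\Delta f, \lambda)$ plane.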

Paper Structure

This paper contains 66 sections, 1 theorem, 28 equations, 32 figures, and 1 table.

Key Result

Corollary 5.1

Let $\eta_{\mathrm{max}}$ be the maximum trainable learning rate for a given initialization. The bifurcation diagram is observed up to $\eta < \eta_{\mathrm{max}}$. If $\eta_{\mathrm{max}} < \eta_c$, the UV model does not exhibit EoS.
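
A short worked reading of where $\eta_c$ sits, using only quantities quoted in the figure captions. It is a sketch under the assumption that $\eta_c$ marks the learning rate at which the last zero-loss fixed point loses stability; this summary does not state that derivation explicitly. On the zero-loss line $\Delta f = 0$, the forbidden-region condition leaves only sharpness values above a floor $\lambda_{\min}$, and a zero-loss fixed point is stable only when $\eta \lambda < 2$:

$$
\lambda \;\ge\; \lambda_{\min} = \frac{2 \|\bm{x}\| \, |y|}{\sqrt{n_{\text{eff}}}},
\qquad
\eta \lambda_{\min} < 2
\;\iff\;
\eta < \frac{2}{\lambda_{\min}} = \frac{\sqrt{n_{\text{eff}}}}{\|\bm{x}\| \, |y|}.
$$

With $\|\bm{x}\| = 1$ and $y = 2$ the threshold is $\sqrt{n_{\text{eff}}}/2$, and $0.5$ for $n_{\text{eff}} = 1$, matching the values of $\eta_c$ given in the captions of Figures 2 and 3. Read this way, the corollary says that if training already fails at some $\eta_{\mathrm{max}}$ below this threshold, every trainable learning rate still has stable zero-loss fixed points available and EoS is never reached.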

Figures (32)

  • Figure 1: Training loss and sharpness trajectories of ReLU FCNs trained on a $5$k subset of CIFAR-10 examples using MSE loss and GD: (a, d) SP with $\sigma^2_w = 0.5$, (b, e) SP with $\sigma^2_w = 2.0$, (c, f) $\mu$P with $\sigma^2_w = 2.0$. The dashed lines in the sharpness figures show the $2/\eta$ threshold.
  • Figure 2: Training trajectories of the UV model with $\|\bm{x}\| = 1$ and $y=2$ in the $(\Delta f, \lambda)$ plane for different values of $n$, $n_{\text{eff}}$ and $\eta$. The columns show initializations with different $n$ and $n_{\text{eff}}$, while the rows represent increasing learning rates for fixed initializations. The horizontal dash-dot line $\eta \lambda = 2$ separates the stable (solid black vertical line) and unstable (dashed black vertical line) fixed points along the zero loss fixed line I. Forbidden regions, $2 \|\bm{x}\| |\Delta f + y| / \sqrt{n_{\text{eff}}} > \lambda$ (see the appendix on forbidden regions), are shaded gray. The nullclines $\Delta f_{t+1} = \Delta f_t$ and $\lambda_{t+1} = \lambda_t$ are shown as orange and white dashed curves, respectively. Sharpness reduction, progressive sharpening, and divergent regions are colored green, yellow, and blue, respectively. The gray arrows indicate the local vector field $\hat{G}(\Delta f, \lambda)$, which is the direction of the updates. The training trajectories are depicted as black lines with arrows, with the star marking the initialization. In all cases, $\eta_c = \sqrt{n_{\text{eff}}} / 2$ (introduced in the section on EoS in the UV model).
  • Figure 3: (a) Bifurcation diagram depicting limiting values of $\lambda$ obtained by simulating the EoS map of the UV model. (b) Bifurcation diagram of the UV model. In both figures, $\| \bm{x}\| = 1$, $y = 2$, $n_{\text{eff}} = 1$, and $\eta_c = 0.5$.
  • Figure 4: (a, b) Heatmap of $\eta {\lambda}^H/2$ and test accuracy of ReLU FCNs in SP trained on a $5$k subset of CIFAR-10 until $99\%$ training accuracy is achieved, with the weight variance $\sigma^2_w$ and learning rate multiplier $c = \eta \lambda_0^H$ as axes. As the color varies from blue to white, $\eta {\lambda}^H/2$ increases. (c, d) Same heatmaps with fixed $\sigma_w^2=2.0$, but varying $s$ continuously.
  • Figure 5: $2$-layer linear FCNs trained on (first row) $5,000$ iid random examples with unit output dimension and (second row) $5,000$ CIFAR-10 examples. Different columns correspond to the bifurcation diagram, late-time sharpness trajectories, and the power spectrum of sharpness trajectories. The power spectrum is computed using the last $1000$ steps of the trajectories.
  • ...and 27 more figures
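
The Figure 5 caption says the power spectrum is computed from the last $1000$ steps of the sharpness trajectory but does not specify the estimator. The sketch below assumes a plain periodogram after mean removal; the function name `power_spectrum` and the synthetic period-2 signal are illustrative stand-ins, not data or code from the paper.

```python
import numpy as np

def power_spectrum(series, last=1000):
    """Periodogram of the final `last` samples after removing their mean."""
    tail = np.asarray(series, dtype=float)[-last:]
    tail = tail - tail.mean()                      # drop the zero-frequency offset
    power = np.abs(np.fft.rfft(tail)) ** 2 / tail.size
    freqs = np.fft.rfftfreq(tail.size, d=1.0)      # cycles per training step
    return freqs, power

# Synthetic stand-in for a late-time sharpness trajectory:
# a period-2 oscillation around 2 / eta (hypothetical values).
eta = 0.6
steps = np.arange(4000)
lam = 2.0 / eta + 0.1 * (-1.0) ** steps

freqs, power = power_spectrum(lam)
peak_freq = freqs[np.argmax(power)]
```

For this period-2 signal the spectrum has a single peak at $0.5$ cycles per step. A period-4 oscillation would add a peak at $0.25$, which is how successive period doublings can be read off such a spectrum, while a broadband spectrum points to chaotic dynamics.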

Theorems & Definitions (1)

  • Corollary 5.1