A PDE-based Explanation of Extreme Numerical Sensitivities and Edge of Stability in Training Neural Networks

Yuxin Sun, Dong Lao, Ganesh Sundaramoorthi, Anthony Yezzi

TL;DR

This work presents a PDE-based framework to explain restrained numerical instabilities observed during SGD training of CNNs. By deriving a gradient-flow PDE and its discretization, the authors establish CFL-type stability bounds linking learning rate and weight decay, and show that instabilities can remain localized and oscillatory (restrained) rather than causing global divergence. They connect these restrained instabilities to the Edge of Stability (EoS), offering mechanistic insight into why practical training often operates beyond classical stability limits and how regularization and network depth influence these dynamics. The findings suggest adaptive, space-time localized learning-rate strategies to maintain stability while preserving training efficiency, with broad implications for understanding and mitigating numerical conditioning issues in deep learning practice.

Abstract

We discover restrained numerical instabilities in current training practices of deep networks with stochastic gradient descent (SGD) and its variants. We show that numerical error (on the order of the smallest floating-point bit, and thus the most extreme or limiting numerical perturbation induced by floating-point arithmetic in training deep nets) can be amplified significantly and result in significant test accuracy variance (sensitivities), comparable to the test accuracy variance due to stochasticity in SGD. We show how this can likely be traced to instabilities of the optimization dynamics that are restrained, i.e., localized over iterations and regions of the weight tensor space. We do this by presenting a theoretical framework using numerical analysis of partial differential equations (PDE) and analyzing the gradient descent PDE of convolutional neural networks (CNNs). We show that it is stable only under certain conditions on the learning rate and weight decay. We show that, rather than blowing up when the conditions are violated, the instability can be restrained. We show this is a consequence of the non-linear PDE associated with the gradient descent of the CNN, whose local linearization changes when over-driving the step size of the discretization, resulting in a stabilizing effect. We link restrained instabilities to the recently discovered Edge of Stability (EoS) phenomenon, in which the stable step size predicted by classical theory is exceeded while the loss continues to be optimized and training still converges. Because restrained instabilities occur at the EoS, our theory provides new insights and predictions about the EoS, in particular, the role of regularization and the dependence on network complexity.
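
The stabilizing mechanism described above (over-driving the step size moves the iterate to where the local linearization is different) can be illustrated with a one-dimensional toy problem. This is a hypothetical analogy chosen for illustration, not the CNN PDE analyzed in the paper: it compares gradient descent on the quadratic $f(x) = x^2/2$ with gradient descent on $f(x) = \log\cosh x$. Both have curvature $1$ at the minimizer, so the classical stable step size is $\Delta t < 2$ for both. With $\Delta t = 3$, the quadratic diverges geometrically, while for $\log\cosh$ the curvature decays away from $0$ and the gradient is bounded, so the instability stays restrained.

```python
import numpy as np


def gradient_descent(grad, x0, lr, steps):
    """Plain gradient descent x_{n+1} = x_n - lr * grad(x_n); returns all iterates."""
    xs = [x0]
    x = x0
    for _ in range(steps):
        x = x - lr * grad(x)
        xs.append(x)
    return np.array(xs)


def quad_grad(x):
    # f(x) = x**2 / 2: curvature is 1 everywhere, so the stable range is 0 < lr < 2.
    return x


def logcosh_grad(x):
    # f(x) = log(cosh(x)): curvature is 1 at x = 0 but decays away from 0,
    # and |f'(x)| = |tanh(x)| <= 1 bounds the size of every step by lr.
    return np.tanh(x)


lr = 3.0   # exceeds the linearized stability limit 2 / f''(0) = 2 for both losses
x0 = 1e-8  # tiny perturbation of the common minimizer x = 0

quad_path = gradient_descent(quad_grad, x0, lr, 100)
logcosh_path = gradient_descent(logcosh_grad, x0, lr, 100)

print(np.abs(quad_path).max())      # grows like 2**n: unrestrained instability
print(np.abs(logcosh_path).max())   # stays bounded: restrained instability
print(logcosh_path[-2:])            # alternates in sign with nearly equal magnitude
```

For $\log\cosh$, the iterates end up alternating near $\pm 1.29$, the nonzero solutions of $2x = 3\tanh x$: the loss stops decreasing but does not blow up, loosely mirroring the bounded "background noise" of a restrained instability.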

Paper Structure (31 sections, 8 theorems, 90 equations, 11 figures, 9 tables)

Key Result

Theorem 5.1

The gradient descent PDE with respect to the loss \eqref{eq:loss} is […], where $r'$ denotes the derivative of the activation and $\partial_t$ denotes the partial derivative with respect to time. If we impose the constraint that $K$ has compact support ($K$ is zero outside $[-w/2,w/2]^2$), then the constrained (or projected) $\mathbb{L}^2$ gradient descent is given by […], where $W$ is a windowing function ($1$ …
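
The projection step in Theorem 5.1 can be illustrated with a discrete, one-dimensional sketch. Everything specific in it is an assumption made for illustration rather than the paper's setup: a hypothetical single-layer model $\tanh(K * I)$ with circular convolution, a squared-error loss against a target $J$, and a window $W$ equal to $1$ on kernel offsets in $[-h, h]$ and $0$ elsewhere. The function names are illustrative. What it shows is that the projected gradient is the unconstrained gradient multiplied by $W$, so an explicit Euler step keeps $K$ supported inside the window.

```python
import numpy as np


def conv_circ(K, I):
    """Circular 1-D convolution z[x] = sum_j K[j] * I[(x - j) mod N]."""
    # np.roll(I, j)[x] equals I[(x - j) mod N].
    return sum(K[j] * np.roll(I, j) for j in range(len(K)))


def loss(K, I, J):
    """Hypothetical loss 0.5 * sum_x (tanh((K * I)[x]) - J[x])**2."""
    return 0.5 * np.sum((np.tanh(conv_circ(K, I)) - J) ** 2)


def window(N, h):
    """W[j] = 1 for circular kernel offsets j in [-h, h], 0 elsewhere."""
    j = np.arange(N)
    return ((j <= h) | (j >= N - h)).astype(float)


def projected_grad(K, I, J, W):
    """Gradient of `loss` with respect to K, multiplied by the window W."""
    z = conv_circ(K, I)
    e = (np.tanh(z) - J) * (1.0 - np.tanh(z) ** 2)  # residual times r'(z) for r = tanh
    g = np.array([np.dot(e, np.roll(I, j)) for j in range(len(K))])
    return W * g


rng = np.random.default_rng(0)
N, h, dt = 16, 2, 0.005
W = window(N, h)
I = rng.standard_normal(N)
J = np.tanh(conv_circ(W * rng.standard_normal(N), I))  # target reachable by a windowed kernel

K = np.zeros(N)
losses = [loss(K, I, J)]
for _ in range(200):
    K = K - dt * projected_grad(K, I, J, W)  # explicit Euler step of the projected flow
    losses.append(loss(K, I, J))
```

The activation derivative $r'$ enters through the factor `1 - tanh(z)**2`; swapping in another activation requires changing that factor to match.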

Figures (11)

  • Figure 1: Illustration of instability in discretizing the heat equation. The initial condition $u_0$ is a triangle (blue), with boundary conditions $u(0)=u(30)=0$. [Left]: When the CFL condition is met, i.e., $\kappa \Delta t/ (\Delta x)^2 = 0.4 < \frac{1}{2}$, the method is stable and approximates the solution of the PDE. Note the true steady state is 0, which matches the plot. [Right]: When the CFL condition is not met, i.e., $\kappa \Delta t/(\Delta x)^2 = 0.8 > \frac{1}{2}$, small numerical errors are amplified and the scheme diverges.
  • Figure 2: Demonstration of restrained instability related to "moderately aggressive" choices of step size in the discretized Beltrami PDE. While temporary instabilities continuously develop locally throughout the optimization process, the oscillations they generate change the local linearization, which in turn has a temporary, local stabilizing effect that keeps the instability from exploding. As such, we can optimize despite these restrained instabilities and obtain a candidate minimizer that contains a level of effective "background noise" arising from the restrained instabilities when exceeding the maximum stable step size by a factor of 10 (left) and 100 (middle). Both converge in a stochastic sense, but with a degraded signal-to-noise ratio for the larger step size. Continued optimization steps beyond this effective convergence (green curves) simply change the noise pattern (red curves). If the stable step size is exceeded by too much (e.g., by a factor of 1000 in the right plot), the instability is no longer restrained and the process does not converge even in a stochastic sense.
  • Figure 3: [Left]: Relative L1 difference in weights across epochs for SGD ($k=1$) and modified SGD ($k=3$), using a decayed learning-rate schedule. The errors build up quickly but are then contained at later epochs. The initial build-up of errors is enough to produce different test accuracies, which gives evidence of restrained instability. [Right]: Relative L1 difference in weights over the initial iterations for SGD ($k=1$) and modified SGD ($k=3$) with various fixed learning rates. Below a threshold learning rate, the floating-point perturbation is attenuated; above it, errors are amplified, suggesting an instability.
  • Figure 4: Empirical validation of stability bounds for the linearized PDE \eqref{eq:linear_PDE}. When the weight decay is chosen such that $\alpha \in (\alpha_{min}, \alpha_{max})$, the PDE is stable and the loss converges; otherwise, the PDE can be unstable and diverge. The pink and brown lines are clipped because the loss exceeded the largest representable floating-point value.
  • Figure 5: [Left]: Loss vs. iterations of the non-linear PDE \eqref{eq:PDE_constraint} for various choices of learning rate ($\Delta t$). The losses are consistent with the theoretical predictions for the fully stable, fully unstable, and restrained-instability regimes. [Right]: L1 error accumulation in the non-linear PDE. The plot is consistent with the error accumulation expected in the restrained-instability and fully stable regimes.
  • ...and 6 more figures
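
The heat-equation experiment of Figure 1 can be reproduced in outline, assuming it uses the standard explicit scheme for $u_t = \kappa u_{xx}$ (forward Euler in time, centred differences in space), whose CFL condition is $\kappa\Delta t/(\Delta x)^2 \le \tfrac{1}{2}$. The sketch assumes $\Delta x = 1$ on $[0, 30]$ and a triangle with peak $1$; the grid, the number of steps, and the $10^{-12}$ perturbation are illustrative choices. It also loosely mimics the measurement of Figure 3 by comparing two runs whose initial conditions differ by a tiny amount: the difference is damped when the condition holds and amplified when it is violated.

```python
import numpy as np


def heat_explicit(u0, r, steps):
    """Forward Euler in time, centred differences in space, for u_t = kappa * u_xx.

    r = kappa * dt / dx**2 is the CFL number; the end points stay at 0
    (Dirichlet boundary conditions u(0) = u(30) = 0).
    """
    u = u0.astype(np.float64).copy()
    for _ in range(steps):
        # The right-hand side is evaluated in full before assignment,
        # so every interior point is updated from the previous time level.
        u[1:-1] = u[1:-1] + r * (u[2:] - 2.0 * u[1:-1] + u[:-2])
    return u


x = np.arange(31, dtype=np.float64)     # grid on [0, 30] with dx = 1
u0 = np.minimum(x, 30.0 - x) / 15.0     # triangle with peak 1 at x = 15

stable = heat_explicit(u0, r=0.4, steps=200)     # 0.4 < 1/2: CFL condition met
unstable = heat_explicit(u0, r=0.8, steps=200)   # 0.8 > 1/2: CFL condition violated

# Perturb one interior value by a tiny amount and compare against the unperturbed runs.
u0_pert = u0.copy()
u0_pert[15] += 1e-12
growth_stable = np.max(np.abs(heat_explicit(u0_pert, 0.4, 200) - stable))
growth_unstable = np.max(np.abs(heat_explicit(u0_pert, 0.8, 200) - unstable))

print(np.max(np.abs(stable)), np.max(np.abs(unstable)))
print(growth_stable, growth_unstable)
```

With $r = 0.8$, the highest-frequency grid mode is multiplied by $1 - 4r\sin^2(29\pi/60) \approx -2.19$ per step, so any error in that mode, including rounding error, grows geometrically.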

Theorems & Definitions (8)

  • Theorem 5.1: Gradient Descent PDE of the Loss \eqref{eq:loss}
  • Theorem 5.2: Linearized PDE
  • Theorem 5.3: DFT of Discretization
  • Theorem 5.4: Stability Conditions
  • Theorem 5.5: Nesterov Accelerated Gradient Descent and its Underlying PDE
  • Theorem 5.6: Stability Conditions for Nesterov Gradient Descent
  • Theorem 5.7: Multi-layer CNN Kernel Gradient
  • Theorem 5.1
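
Theorem 5.4 and Figure 4 describe a two-sided stable range $(\alpha_{min}, \alpha_{max})$ for the weight decay. The paper's conditions are derived for the discretized CNN PDE; the sketch below only illustrates why such a window can arise, using a generic linear model assumed for illustration: gradient descent $w \leftarrow w - \Delta t\,(A w + \alpha w)$ with a symmetric, possibly indefinite matrix $A$ standing in for a linearized operator. Each eigen-direction with eigenvalue $\lambda$ is scaled by $1 - \Delta t(\lambda + \alpha)$ per step, so stability requires $-\lambda_{min} < \alpha < 2/\Delta t - \lambda_{max}$. The helper names are hypothetical.

```python
import numpy as np


def weight_decay_window(A, dt):
    """Open interval of alpha for which w <- w - dt * (A @ w + alpha * w) is stable.

    A is symmetric (possibly indefinite). Each eigen-direction with eigenvalue lam
    is scaled by 1 - dt * (lam + alpha) per step, so stability needs
    0 < lam + alpha < 2 / dt for every lam.
    """
    lam = np.linalg.eigvalsh(A)
    return -lam.min(), 2.0 / dt - lam.max()


def final_norm(A, alpha, dt, steps=500, seed=0):
    """Norm of w after `steps` explicit steps from a random start."""
    w = np.random.default_rng(seed).standard_normal(A.shape[0])
    for _ in range(steps):
        w = w - dt * (A @ w + alpha * w)
    return np.linalg.norm(w)


A = np.diag([-0.5, 1.0, 3.0])   # toy linearized operator with one negative eigenvalue
dt = 0.25
alpha_min, alpha_max = weight_decay_window(A, dt)   # about (0.5, 5.0)

for alpha in (0.0, 2.0, 6.0):   # below, inside, and above the window
    print(alpha, final_norm(A, alpha, dt))
```

In this toy model, a negative eigenvalue sets the lower bound (too little weight decay leaves a growing direction), while the largest eigenvalue together with the step size sets the upper bound (too much weight decay over-drives the step).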