On the Validity of Modeling SGD with Stochastic Differential Equations (SDEs)

Zhiyuan Li, Sadhika Malladi, Sanjeev Arora

TL;DR

The paper investigates the validity of modeling finite-LR SGD with the Itô SDE $dX_t = - \nabla \mathcal{L}(X_t) dt + (\eta \Sigma(X_t))^{1/2} dW_t$ by introducing Stochastic Variance Amplified Gradient (SVAG), a provably convergent simulator that weakly approximates the SDE as the amplification factor $l$ grows. It proves that SVAG is an order-1 weak approximation of the SDE, provides a testable NSR-based condition for the SDE approximation and the Linear Scaling Rule (LSR) to hold, and demonstrates via extensive experiments that SVAG trajectories match SGD in standard vision tasks, while breaking away when LSR fails. The analysis also shows that non-Gaussian gradient noise is not essential for performance and that LSR failures can be anticipated from equilibrium statistics. Overall, the work validates the Itô SDE perspective as a meaningful lens for SGD dynamics in realistic networks and offers practical guidance for choosing hyperparameters and interpreting generalization phenomena.
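To make the SDE above concrete, here is a minimal Euler-Maruyama sketch in one dimension. The toy loss $\mathcal{L}(x) = x^2/2$, the constant noise scale, and the function name `euler_maruyama` are assumptions chosen for illustration and do not come from the paper; `sigma(x)` plays the role of $\Sigma(x)^{1/2}$.

```python
import math
import random


def euler_maruyama(grad, sigma, eta, x0, T, dt, rng):
    """Simulate dX = -grad(X) dt + sqrt(eta) * sigma(X) dW on [0, T] (1-D).

    grad, sigma, and the toy setup are illustrative assumptions;
    sigma(x) stands in for Sigma(x)^{1/2} in one dimension.
    """
    x = x0
    n_steps = int(round(T / dt))
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment ~ N(0, dt)
        x = x - grad(x) * dt + math.sqrt(eta) * sigma(x) * dw
    return x


# Toy loss L(x) = x^2 / 2 with a constant noise scale (both hypothetical).
rng = random.Random(0)
x_T = euler_maruyama(grad=lambda x: x, sigma=lambda x: 1.0,
                     eta=0.1, x0=1.0, T=1.0, dt=0.01, rng=rng)
```

In this time scale, one SGD step with learning rate $\eta$ corresponds to SDE time $\eta$, so the discretization step `dt` has to be much smaller than $\eta$ for the simulation to be faithful.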

Abstract

It is generally recognized that finite learning rate (LR), in contrast to infinitesimal LR, is important for good generalization in real-life deep nets. Most attempted explanations propose approximating finite-LR SGD with Itô Stochastic Differential Equations (SDEs), but formal justification for this approximation (e.g., Li et al., 2019) only applies to SGD with tiny LR. Experimental verification of the approximation appears computationally infeasible. The current paper clarifies the picture with the following contributions: (a) An efficient simulation algorithm, SVAG, that provably converges to the conventionally used Itô SDE approximation. (b) A theoretically motivated, testable necessary condition for the SDE approximation and its most famous implication, the linear scaling rule (Goyal et al., 2017), to hold. (c) Experiments using this simulation to demonstrate that the previously proposed SDE approximation can meaningfully capture the training and generalization properties of common deep nets.
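A variance-amplified step of the kind contribution (a) describes can be sketched as follows. The specific weights are an assumption of this sketch, since this summary does not reproduce the SVAG update rule: they are chosen so that the two weights sum to 1 (the expected gradient is unchanged) and their squares sum to $l$ (the noise covariance is multiplied by $l$). With step size $\eta/l$, the per-step noise covariance is then $(\eta/l)^2 \cdot l\,\Sigma = \eta^2\Sigma/l$, so $l$ such steps accumulate the same $\eta^2\Sigma$ as one SGD step. The toy gradient `noisy_grad` is hypothetical.

```python
import math
import random


def svag_weights(l):
    """Two weights that sum to 1 and whose squares sum to l (needs l >= 1)."""
    s = math.sqrt(2 * l - 1)
    return (1 + s) / 2, (1 - s) / 2


def svag_step(x, eta, l, sample_grad, rng):
    """One SVAG-style step: combine two independent stochastic gradients,
    then move with step size eta / l."""
    r1, r2 = svag_weights(l)
    g = r1 * sample_grad(x, rng) + r2 * sample_grad(x, rng)
    return x - (eta / l) * g


def noisy_grad(x, rng):
    # Hypothetical per-batch gradient of L(x) = x^2 / 2 with unit-variance noise.
    return x + rng.gauss(0.0, 1.0)


rng = random.Random(0)
x = 1.0
eta, l = 0.1, 4
for _ in range(l * 10):  # l SVAG steps cover the same SDE time as one SGD step
    x = svag_step(x, eta, l, noisy_grad, rng)
```

For $l = 1$ the weights are $(1, 0)$ and the step reduces to plain SGD, which is a convenient sanity check. Each step needs only two stochastic gradients, so simulating SDE time $T$ costs roughly $2l$ times the gradient evaluations of the corresponding SGD run.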

Paper Structure

This paper contains 39 sections, 20 theorems, 80 equations, 20 figures.

Key Result

Theorem 4.3

Suppose the following conditions are met (the $\mathcal{C}^\infty$ smoothness assumptions can be relaxed by using the mollification technique in li2019stochastic): Let $T>0$ be a constant and $l$ be the SVAG hyperparameter (eq:svag_iter). Define $\{ X_t: t\in[0,T] \}$ as the stochastic process (independent of $\eta$) satisfying the Itô SDE (eq:motivating_sde) and $\{x_k^{\eta/l}: 1\le k\le \lfloor

Figures (20)

  • Figure 1: Itô SDE (\ref{eq:motivating_sde}), SVAG (\ref{eq:svag_iter}), and SGD (\ref{eq:sgd_iter}) trajectories (blue) sampled from a distribution (green). li2019stochastic show that $\forall T, \exists \eta$ such that SDE (a) and SGD (c) are order-1 weak approximations (\ref{def:weak_approx}) of each other. Our result (\ref{thm:sde_svag}) shows that $\forall T, \eta$, $\exists l$ such that SDE (a) and SVAG (b) are order-1 weak approximations of each other. In particular, li2019stochastic requires an infinitesimal $\eta$ and our result holds for finite $\eta$.
  • Figure 2: Experimental verification of our theory for predicting the failure of the Linear Scaling Rule. We modify PreResNet-32 and VGG-19 to be scale-invariant (according to Appendix C of li2020reconciling). All three settings use the same LR schedule: the LR is initially $0.8$ and is decayed by a factor of $0.1$ at epoch $250$, with a total budget of $300$ epochs. Here, $G_t$ and $N_t$ are the empirical estimates of $G_\infty$ and $N_\infty$ taken after reaching equilibrium in the first phase (before LR decay). Per the approximated version of \ref{thm:lsr_kappa_bound}, i.e., $B^*=\kappa B \lesssim C^2B{N_\infty^B}/{G_\infty^B}$, we use baseline runs with different batch sizes $B$ to report the maximal and minimal predicted critical batch size, defined as the intersection of the threshold ($G_t/N_t=C^2$) with the green and blue lines, respectively. We choose a threshold of $C^2 = 2$, and consider LSR to fail if the final test error exceeds the lowest achieved test error by more than 20% of its value, marked by the red region on the plot. Further settings and discussion are in \ref{sec:app_exp}.
  • Figure 3: Non-Gaussian noise is not essential to SGD performance. SGD with batch size $125$ and NGD with matching covariance have close train and test curves when training on CIFAR-10. All three settings use $\eta=0.8$, decayed by a factor of $0.1$ at step $24000$. GD achieves 75.5% test accuracy, and SGD and NGD achieve 89.4% and 89.3%, respectively. We smooth the training curve by dividing it into intervals of 100 steps and recording the average. For efficient sampling of Gaussian noise, we use GroupNorm instead of BatchNorm and turn off data augmentation. See implementation details in \ref{sec:app_exp}.
  • Figure 4:
  • Figure 5: Taking $\eta\to 0$ while keeping the first two moments is not enough to converge to the Itô SDE limit; e.g., decreasing the LR along LSR can converge to another limit, a Lévy SDE. Red and blue arrows denote taking the limit of the dynamics as LR $\eta\to 0$ along SVAG and LSR, respectively. Here we assume the noise in SGD is infinitely divisible, so that the LR can go to $0$ along LSR. For NGD, i.e., SGD with Gaussian noise, both SVAG and LSR (Linear Scaling Rule) approach the same continuous limit. This does not hold for SGD with non-Gaussian noise.
  • ...and 15 more figures
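The approximate bound quoted in the Figure 2 caption, $B^*=\kappa B \lesssim C^2 B N_\infty^B / G_\infty^B$, is simple arithmetic once the equilibrium statistics are measured. A minimal sketch follows; the function name and the numeric inputs are hypothetical, and `G` and `N` stand for the empirical estimates $G_t$ and $N_t$ named in the caption.

```python
def predicted_critical_batch_size(B, G, N, C2=2.0):
    """Approximate bound B* = kappa * B <~ C^2 * B * N / G (Figure 2 caption)."""
    if G <= 0:
        raise ValueError("G must be positive")
    return C2 * B * N / G


# Hypothetical equilibrium estimates from a baseline run at batch size 128.
b_star = predicted_critical_batch_size(B=128, G=0.5, N=4.0, C2=2.0)
print(b_star)  # 2048.0
```

With these made-up numbers, the sketch suggests LSR would be expected to hold up to a batch size of roughly 2048, i.e. $\kappa \approx 16$ relative to the baseline.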

Theorems & Definitions (48)

  • Definition 2.1: Linear Scaling Rule (LSR)
  • Definition 4.1: Test Functions
  • Definition 4.2: Order-$\alpha$ weak approximation
  • Theorem 4.3
  • Remark 4.4
  • Lemma 4.4
  • Lemma 4.4
  • Definition 5.1: $C$-closeness
  • Theorem 5.2
  • Remark 5.3
  • ...and 38 more