On the Validity of Modeling SGD with Stochastic Differential Equations (SDEs)
Zhiyuan Li, Sadhika Malladi, Sanjeev Arora
TL;DR
The paper investigates whether finite-LR SGD can validly be modeled by the Itô SDE $dX_t = -\nabla \mathcal{L}(X_t)\,dt + (\eta \Sigma(X_t))^{1/2}\,dW_t$. To test this, it introduces Stochastic Variance Amplified Gradient (SVAG), an efficient simulation algorithm whose iterates converge weakly to this SDE as the amplification factor $l$ grows, and it proves that SVAG is an order-1 weak approximation of the SDE. It also derives a testable condition, based on the noise-to-signal ratio (NSR) of the stochastic gradient, that is necessary for the SDE approximation and the Linear Scaling Rule (LSR) to hold. Extensive experiments show that SVAG trajectories match SGD on standard vision tasks and deviate from SGD in settings where the LSR fails. The analysis also shows that non-Gaussian gradient noise is not essential for performance and that LSR failures can be anticipated from statistics measured at equilibrium. Overall, the work supports the Itô SDE as a meaningful lens on SGD dynamics in realistic networks and offers practical guidance for choosing hyperparameters and interpreting generalization phenomena.
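To make the SVAG construction concrete, here is a minimal NumPy sketch. It assumes the update as we understand it from the paper: at each step, draw two independent minibatches, combine their gradients with weights $(1 \pm \sqrt{2l-1})/2$, and step with learning rate $\eta/l$, running $l$ times as many steps as SGD. The weights sum to 1, so the expected gradient is unchanged, while their squares sum to $l$, so the gradient-noise variance is multiplied by $l$. Over $l$ steps of size $\eta/l$, the drift and the accumulated noise variance therefore both match one SGD step with learning rate $\eta$, and $l = 1$ recovers plain SGD. The names `svag_coefficients`, `svag_step`, `run_svag`, and the `grad_fn`/`sample_batch` callables are hypothetical and chosen for illustration; this is not the authors' code.

```python
import numpy as np


def svag_coefficients(l: int) -> tuple[float, float]:
    """Weights for combining two independent minibatch gradients.

    They sum to 1 (the mean gradient is preserved) and their squares sum
    to l (the gradient-noise variance is amplified by a factor of l).
    """
    if l < 1:
        raise ValueError("amplification factor l must be >= 1")
    s = np.sqrt(2 * l - 1)
    return float((1 + s) / 2), float((1 - s) / 2)


def svag_step(x, grad_fn, sample_batch, eta, l, rng):
    """One SVAG update: combined gradient, learning rate eta / l.

    grad_fn(x, batch) and sample_batch(rng) are hypothetical user-supplied
    callables returning a minibatch gradient and a fresh minibatch.
    """
    a, b = svag_coefficients(l)
    g1 = grad_fn(x, sample_batch(rng))  # first independent minibatch
    g2 = grad_fn(x, sample_batch(rng))  # second independent minibatch
    return x - (eta / l) * (a * g1 + b * g2)  # note: b < 0 when l > 1


def run_svag(x0, grad_fn, sample_batch, eta, l, n_sgd_steps, seed=0):
    """Run l * n_sgd_steps SVAG steps, covering the same SDE time
    (eta * n_sgd_steps) as n_sgd_steps of SGD with learning rate eta."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    for _ in range(l * n_sgd_steps):
        x = svag_step(x, grad_fn, sample_batch, eta, l, rng)
    return x


if __name__ == "__main__":
    # Toy problem (assumption): per-example loss 0.5 * (x - c_i)^2.
    data = np.random.default_rng(1).standard_normal(1000)
    grad_fn = lambda x, batch: x - batch.mean()
    sample_batch = lambda rng: rng.choice(data, size=10)
    for l in (1, 4, 16):
        x = run_svag(3.0, grad_fn, sample_batch, eta=0.1, l=l, n_sgd_steps=500)
        print(f"l={l:2d}  final x={float(x):+.4f}  data mean={data.mean():+.4f}")
```

In this sketch each SVAG step costs two gradient evaluations and there are $l$ times as many steps, so matching a given SDE time costs roughly $2l$ times the gradient evaluations of SGD; that cost grows with $l$ but stays far below simulating the SDE directly with a tiny step size.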
Abstract
It is generally recognized that a finite learning rate (LR), in contrast to an infinitesimal LR, is important for good generalization in real-life deep nets. Most attempted explanations propose approximating finite-LR SGD with Itô Stochastic Differential Equations (SDEs), but formal justification for this approximation (e.g., Li et al., 2019) only applies to SGD with tiny LR. Experimental verification of the approximation appears computationally infeasible. The current paper clarifies the picture with the following contributions: (a) An efficient simulation algorithm, SVAG, that provably converges to the conventionally used Itô SDE approximation. (b) A theoretically motivated, testable necessary condition for the SDE approximation and its most famous implication, the linear scaling rule (Goyal et al., 2017), to hold. (c) Experiments using this simulation to demonstrate that the previously proposed SDE approximation can meaningfully capture the training and generalization properties of common deep nets.
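The linear scaling rule in (b) says that when the batch size is multiplied by κ, the LR should be multiplied by κ as well. The sketch below shows the bookkeeping behind why the SDE view suggests this rule, under stated assumptions: if minibatches are drawn with replacement, the minibatch gradient covariance is $\Sigma_B = \Sigma_1/B$ for a per-example covariance $\Sigma_1$ (notation introduced here), so the diffusion term $(\eta \Sigma_B)^{1/2}$ depends on $(\eta, B)$ only through $\eta/B$, and each SGD step advances SDE time by $\eta$. Scaling both $\eta$ and $B$ by κ and running $1/\kappa$ as many steps therefore leaves the SDE, its time horizon, and the number of epochs unchanged. `SGDConfig` and `linear_scaling` are hypothetical names; the code checks only these invariants, not whether the SDE approximation actually holds at the larger LR, which is what the paper's testable necessary condition is meant to probe.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SGDConfig:
    lr: float        # learning rate eta
    batch_size: int  # minibatch size B
    steps: int       # number of SGD updates

    def sde_noise_scale(self) -> float:
        """eta / B: assuming Sigma_B = Sigma_1 / B, the SDE diffusion term
        (eta * Sigma_B)^{1/2} depends on (eta, B) only through this ratio."""
        return self.lr / self.batch_size

    def sde_time(self) -> float:
        """Continuous time covered: each SGD step advances the SDE clock by eta."""
        return self.lr * self.steps

    def examples_seen(self) -> int:
        """Training examples processed (proportional to the number of epochs)."""
        return self.batch_size * self.steps


def linear_scaling(cfg: SGDConfig, kappa: int) -> SGDConfig:
    """Linear Scaling Rule: multiply B and eta by kappa, divide steps by kappa."""
    if kappa < 1 or cfg.steps % kappa != 0:
        raise ValueError("kappa must be a positive divisor of cfg.steps")
    return SGDConfig(
        lr=cfg.lr * kappa,
        batch_size=cfg.batch_size * kappa,
        steps=cfg.steps // kappa,
    )


base = SGDConfig(lr=0.1, batch_size=128, steps=40_000)
big = linear_scaling(base, 8)
print(big)  # SGDConfig(lr=0.8, batch_size=1024, steps=5000)
```

Both configurations share the same $\eta/B$, SDE time, and epoch count, so under the SDE approximation they should behave alike; when experiments show otherwise, the LSR (and with it the SDE approximation) has broken down.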
