Stochastic modified equations and adaptive stochastic gradient algorithms

Qianxiao Li, Cheng Tai, Weinan E

TL;DR

This work introduces stochastic modified equations (SME) to rigorously approximate SGD in the weak sense, enabling precise dynamical analysis of descent and noise-driven fluctuations beyond convex regimes. By combining SME with optimal control, it derives adaptive learning-rate and momentum policies, yielding robust algorithms (cSGD and cMSGD) that require less hyper-parameter tuning across diverse models and datasets. Theoretical results include first- and second-order weak approximations and stochastic asymptotic expansions, while empirical benchmarks on MNIST and CIFAR-10 validate competitive performance and adaptability. Overall, SME provides a general methodology for analyzing and designing stochastic gradient algorithms with practical, model-agnostic adaptivity.

Abstract

We develop the method of stochastic modified equations (SME), in which stochastic gradient algorithms are approximated in the weak sense by continuous-time stochastic differential equations. We exploit the continuous formulation together with optimal control theory to derive novel adaptive hyper-parameter adjustment policies. Our algorithms have competitive performance with the added benefit of being robust to varying models and datasets. This provides a general methodology for the analysis and design of stochastic gradient algorithms.

Paper Structure

This paper contains 43 sections, 4 theorems, 131 equations, 8 figures, 1 table, 3 algorithms.

Key Result

Theorem 1

Let $\alpha\in\{1,2\}$, $0<\eta<1$, $T>0$ and set $N=\lfloor T/\eta\rfloor$. Let $x_k\in\mathbb{R}^d$, $0\leq k\leq N$, denote a sequence of SGD iterations defined by (2). Define $X_t\in \mathbb{R}^d$ as the stochastic process satisfying the SDE $X_0=x_0$ and $\Sigma(x) = \frac{1}{n} \sum_{i=1}^n (\nabla f(x) - \nabla f_i(x))(\nabla f(x) - \nabla f_i(x))^T$. Fix some test function $g\in G$ (cf. D
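
The weak-approximation claim can be checked numerically on a one-dimensional quadratic. The sketch below is illustrative only and rests on assumptions not stated in the text above: the objective is taken to be $f(x)=x^2$ split as $f_1(x)=(x-1)^2-1$, $f_2(x)=(x+1)^2-1$ (so that $\Sigma=4$), and the first-order SME is taken in the form $dX_t=-\nabla f(X_t)\,dt+(\eta\Sigma)^{1/2}\,dW_t$, which for this objective is an Ornstein-Uhlenbeck process with closed-form mean and variance. The function names `simulate_sgd` and `sme_moments` are hypothetical.

```python
import numpy as np

def simulate_sgd(x0, eta, n_steps, n_runs, rng):
    """Run n_runs independent SGD paths on f(x) = x^2.

    Assumed split (not given in the text above):
    f_1(x) = (x - 1)^2 - 1 and f_2(x) = (x + 1)^2 - 1,
    so grad f_i(x) = 2 (x - s) with s = +1 or -1, and Sigma = 4.
    """
    x = np.full(n_runs, x0, dtype=float)
    for _ in range(n_steps):
        s = rng.choice([-1.0, 1.0], size=n_runs)  # index of the sampled f_i
        x = x - eta * 2.0 * (x - s)               # one SGD step per path
    return x

def sme_moments(x0, eta, t):
    """Mean and standard deviation at time t of the assumed first-order SME
    dX = -2 X dt + 2 sqrt(eta) dW (an Ornstein-Uhlenbeck process)."""
    mean = x0 * np.exp(-2.0 * t)
    var = eta * (1.0 - np.exp(-4.0 * t))
    return mean, np.sqrt(var)

rng = np.random.default_rng(0)
x0, eta, T = 1.0, 5e-3, 1.0
n_steps = int(np.floor(T / eta))      # N = floor(T / eta), as in the theorem

x_T = simulate_sgd(x0, eta, n_steps, n_runs=5000, rng=rng)
sme_mean, sme_std = sme_moments(x0, eta, T)

print("SGD mean/std:", x_T.mean(), x_T.std())
print("SME mean/std:", sme_mean, sme_std)
```

With a small learning rate the empirical moments of the SGD iterates at step $N$ should lie close to the SME moments at time $T=N\eta$; the remaining gap combines the $O(\eta)$ weak error of the first-order approximation with Monte Carlo sampling error.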

Figures (8)

  • Figure 1: Comparison of the SME predictions vs SGD for the simple quadratic objective. We set $x_0=1$, $\eta=$5e-3. (a) The predicted mean and standard deviations agree well with the empirical moments of the SGD, obtained by averaging 5e3 runs. (b) 50 sample SGD paths with the predicted transition time $k^*=t^*/\eta$. We observe that $k^*$ corresponds to the separation of descent and fluctuating regimes for typical sample paths.
  • Figure 2: Comparison of the moments of SGD iterates with the SME and its asymptotic approximation (Asymp, Eq. \ref{eq:asymp}) for the non-convex objective with $\delta=0.2$ and $\epsilon=0.1$. The landscape is shown in (a). In (b), we plot the magnitude of the mean and the covariance matrix for the SGD, SME and Asymp. We take $\eta=$1e-4 and $x_0=(1,1.5)$. All moments are obtained by sampling over 1e3 runs (the SME and Asymp are integrated numerically). We observe a good agreement.
  • Figure 3: cSGD vs Adagrad and Adam for different models and datasets, with different hyper-parameters. For M0, we perform log-uniform random search with 50 samples over intervals: cSGD: $u_0\in$[1e-2,1], $\eta\in$[1e-1,1]; Adagrad: $\eta\in$[1e-3,1]; Adam: $\eta\in$[1e-4,1e-1]. For C0, we perform the same search over intervals: cSGD: $u_0\in$[1e-2,1], $\eta\in$[1e-1,1]; Adagrad: $\eta\in$[1e-3,1]; Adam: $\eta\in$[1e-6,1e-3]. We average the resulting learning curves for each choice over 10 runs. For C1, due to long training times, we choose 5 representative learning rates for each method. cSGD: $\eta\in${1e-2,5e-2,1e-1,5e-1,1}, $u_0=1$; Adagrad: $\eta\in${1e-3,5e-3,1e-2,5e-2,1e-1}; Adam: $\eta\in${5e-4,1e-3,1e-2,2e-2,5e-2}. One sample learning curve is generated for each choice. In all cases, we use mini-batches of size 128. We evaluate the resulting learning curves by the area-under-curve. The worst, median and best learning curves are shown as dotted, solid, and dot-dashed lines respectively. The shaded areas represent the distribution of learning curves for all searched values. We observe that cSGD is relatively robust with respect to initial/maximum learning rates and the network structures, and requires little tuning while having comparable performance to well-tuned versions of the other methods (see Tab. \ref{tab:test_acc}). This holds across different models and datasets.
  • Figure 4: (a) Comparison of the SME prediction \ref{eq:M_eqn} with SGD for the same quadratic example in Sec. \ref{sec:dynamics_example}, which has $a=2$, $b=0$ and $\Sigma=4$. We set $\eta=$5e-3 so that $\mu_{\text{opt}}=0.8$. We plot the mean of $f$ averaged over 1e5 SGD runs against the SME predictions for $\mu=0.65,0.8,0.95$. We observe that in all cases the approximation is accurate. In particular, the SME correctly predicts the effect of momentum: $\mu=\mu_{\text{opt}}$ gives the best average initial descent rate, $\mu>\mu_{\text{opt}}$ causes oscillatory behavior, and increasing $\mu$ generally increases asymptotic fluctuations. (b) The dynamics of averaged equation \ref{eq:m_eqn}, which serves as an approximation of the solution of the full SME moment equation \ref{eq:M_eqn}.
  • Figure 5: cMSGD vs MSGD and MSGD-A on the same three models. We set $\eta=$1e-2 for M0 and $\eta=$1e-3 for C0, C1. For M0 and C0, we perform a log-uniform random search for $1-\mu_0$ and $1-\mu$ in [5e-3,5e-1]. For C1, we sample $\mu_0,\mu,\mu_\text{max}\in${0.9,0.95,0.99,0.995,0.999}. The remaining set-up is identical to that in Fig. \ref{fig:MNIST_test_lr}. Again, we observe that cMSGD is an adaptive scheme that is robust to varying hyper-parameters and network structures, and outperforms MSGD and MSGD-A.
  • ...and 3 more figures

Theorems & Definitions (14)

  • Definition 1
  • Remark 1
  • Remark 2
  • Remark 3
  • Theorem 1: Stochastic modified equations
  • Lemma 1
  • proof
  • Lemma 2
  • proof
  • Theorem 2: Milstein, 1986
  • ...and 4 more