On the SDEs and Scaling Rules for Adaptive Gradient Algorithms

Sadhika Malladi, Kaifeng Lyu, Abhishek Panigrahi, Sanjeev Arora

TL;DR

This work provides rigorously proven Itô SDE approximations for RMSprop and Adam, establishing that these adaptive optimizers admit first-order weak approximations and enabling square root scaling rules when the batch size changes. By introducing the Noisy Gradient Oracle with Scale Parameter (NGOS) and carefully bounding approximation errors under well-behaved noise and polynomial-growth conditions, the authors show how to adjust hyperparameters (e.g., …

Abstract

Approximating Stochastic Gradient Descent (SGD) as a Stochastic Differential Equation (SDE) has allowed researchers to enjoy the benefits of studying a continuous optimization trajectory while carefully preserving the stochasticity of SGD. Analogous study of adaptive gradient methods, such as RMSprop and Adam, has been challenging because there were no rigorously proven SDE approximations for these methods. This paper derives the SDE approximations for RMSprop and Adam, giving theoretical guarantees of their correctness as well as experimental validation of their applicability to common large-scale vision and language settings. A key practical result is the derivation of a square root scaling rule to adjust the optimization hyperparameters of RMSprop and Adam when changing batch size, and its empirical validation in deep learning settings.
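A minimal sketch of the square root scaling rule, assuming the form we understand the paper to use (def:rmsprop_scaling for RMSprop, with the analogous rule for Adam): when the batch size is multiplied by $\kappa$, $\eta$ becomes $\sqrt{\kappa}\,\eta$, each $1-\beta$ becomes $\kappa(1-\beta)$, and $\epsilon$ becomes $\epsilon/\sqrt{\kappa}$. The names `AdamHParams` and `sqrt_scale` are hypothetical, and the exact rule should be checked against the paper's definitions. For RMSprop, only `lr`, `beta2` (playing the role of $\beta$) and `eps` matter.

```python
import math
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class AdamHParams:
    """Hypothetical container for Adam hyperparameters at a given batch size."""
    batch_size: int
    lr: float     # eta
    beta1: float
    beta2: float  # plays the role of RMSprop's beta
    eps: float


def sqrt_scale(h: AdamHParams, new_batch_size: int) -> AdamHParams:
    """Apply the (assumed) square root scaling rule for a batch size change.

    kappa = new_batch_size / h.batch_size
    eta -> sqrt(kappa) * eta, (1 - beta) -> kappa * (1 - beta), eps -> eps / sqrt(kappa)
    """
    kappa = new_batch_size / h.batch_size
    beta1 = 1.0 - kappa * (1.0 - h.beta1)
    beta2 = 1.0 - kappa * (1.0 - h.beta2)
    # Guard (our choice, not from the paper): a large kappa can push a beta out of [0, 1).
    if not (0.0 <= beta1 < 1.0 and 0.0 <= beta2 < 1.0):
        raise ValueError(f"kappa={kappa} gives invalid betas ({beta1}, {beta2})")
    return replace(
        h,
        batch_size=new_batch_size,
        lr=math.sqrt(kappa) * h.lr,
        beta1=beta1,
        beta2=beta2,
        eps=h.eps / math.sqrt(kappa),
    )
```

For the RMSprop setting in Figure 4 ($\eta = 10^{-3}$, $\beta = 0.999$ at batch size 256), moving to batch size 8192 gives $\kappa = 32$, so $\eta \approx 5.66 \cdot 10^{-3}$ and $\beta = 0.968$ under this rule. Scaling back with $\kappa = 1/32$ recovers the original values.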

Paper Structure (67 sections, 24 theorems, 89 equations, 30 figures, 5 tables)

Key Result

Theorem 4.2

Let ${\bm{u}}_k\triangleq{\bm{v}}_k/\sigma^2$ and define the state of the discrete RMSprop trajectory with hyperparameters $\eta,\beta,\epsilon$ (def:rmsprop) as ${\bm{x}}_k = ({\bm{\theta}}_k, {\bm{u}}_k)$. Then, for a well-behaved NGOS (def:NGOS) satisfying the skewness and bounded moments conditions … where $g$ and $T$ are defined as in def:weak_approx and the initial condition of the SDE is …
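To make the state ${\bm{x}}_k = ({\bm{\theta}}_k, {\bm{u}}_k)$ concrete, the sketch below runs RMSprop on a toy quadratic with a hypothetical NGOS-style oracle $g = \nabla f(\theta) + \sigma z$. It assumes def:rmsprop is the common RMSprop update written in the docstring and that $v_0 = 0$; all function names are illustrative. In this noise-dominated toy, $v_k$ grows like $\sigma^2$, so the rescaled $u_k = v_k/\sigma^2$ stays of order one, which suggests why the theorem tracks $u_k$ rather than $v_k$.

```python
import math
import random


def toy_ngos_grad(theta, sigma, rng):
    """Hypothetical toy NGOS: exact gradient of f(theta) = 0.5 * ||theta||^2
    plus sigma * z with z ~ N(0, I), so the noise covariance Sigma(theta) = I."""
    return [t + sigma * rng.gauss(0.0, 1.0) for t in theta]


def rmsprop_step(theta, v, g, eta, beta, eps):
    """One RMSprop step in its common form (assumed to match def:rmsprop):
    v' = beta * v + (1 - beta) * g**2,  theta' = theta - eta * g / (sqrt(v') + eps)."""
    v_new = [beta * vi + (1.0 - beta) * gi * gi for vi, gi in zip(v, g)]
    theta_new = [ti - eta * gi / (math.sqrt(vi) + eps)
                 for ti, gi, vi in zip(theta, g, v_new)]
    return theta_new, v_new


def rmsprop_states(theta0, steps, eta, beta, eps, sigma, seed=0):
    """Return [x_0, ..., x_steps] with x_k = (theta_k, u_k) and u_k = v_k / sigma**2.
    The second-moment estimate starts at v_0 = 0 (an assumption of this sketch)."""
    rng = random.Random(seed)
    theta, v = list(theta0), [0.0] * len(theta0)
    states = [(list(theta), [vi / sigma**2 for vi in v])]
    for _ in range(steps):
        g = toy_ngos_grad(theta, sigma, rng)
        theta, v = rmsprop_step(theta, v, g, eta, beta, eps)
        states.append((list(theta), [vi / sigma**2 for vi in v]))
    return states
```

For example, with $\sigma = 10$, `rmsprop_states([1.0, -2.0], 2000, 1e-3, 0.99, 1e-8, 10.0)[-1][1]` should return entries close to 1, even though the raw $v_k$ is close to 100.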

Figures (30)

  • Figure 1: Square root scaling rule experiments on CIFAR-10 with VGG-16 and ResNet-50 (details in sec:app_exp_config). We plot the mean and variance over 3 random seeds. The same color legend is used across all plots. The performance gap between $B=256$ and $B=8192$ is at most $3\%$ in all cases.
  • Figure 2: Large-scale square root scaling rule experiments (details in sec:app_exp_config). Small- and large-batch models differ by at most $1.5\%$ test accuracy in vision and 0.5 perplexity in language.
  • Figure 3: SVAG applied to the Adam trajectory when training ResNet-50 on CIFAR-10 matches the discrete trajectory ($\ell=1$) on various test functions (see sec:app_exp_config for details). The closeness of the trajectories under these test functions for different values of $\ell$ indicates that the SDE approximation (def:adam_sde) is a first-order weak approximation of Adam (thm:adam_sde).
  • Figure 4: We compare the norm of the average gradient with the noise scale for different batch sizes during training of ResNet-50 and VGG-16 models with RMSprop on the CIFAR-10 dataset. Here, $(\eta, \beta) = (10^{-3}, 0.999)$ for batch size $256$, scaled with our proposed square root scaling rule (def:rmsprop_scaling) for the other batch sizes. We show results for both a small $\epsilon$ (of order $10^{-30}$) and a large $\epsilon$ (of order $10^{-8}$). For small batches, the noise in the gradient dominates the signal, supporting our hypothesis. For larger batches, the hypothesis appears to hold towards the end of training.
  • Figure 5: We compare the norm of the average gradient with the noise scale for different batch sizes during training of a ResNet-50 model with Adam on the CIFAR-10 dataset. Here, $(\eta, \beta_1, \beta_2) = (10^{-3}, 0.999, 0.999)$ for batch size $256$, scaled with our proposed square root scaling rule (def:rmsprop_scaling) for the other batch sizes. We show results for both a small $\epsilon$ (of order $10^{-30}$) and a large $\epsilon$ (of order $10^{-8}$). For small batches, the noise in the gradient dominates the signal, supporting our hypothesis. For larger batches, the hypothesis appears to hold towards the end of training.
  • ...and 25 more figures
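Figure 3 uses SVAG to probe the SDE. Below is a minimal sketch of the gradient-mixing step, assuming the paper reuses the SVAG estimator originally proposed for SGD by Li, Malladi and Arora (2021): two independent stochastic gradients $g_1, g_2$ are combined as $r_1 g_1 + r_2 g_2$ with $r_{1,2} = (1 \pm \sqrt{2\ell - 1})/2$. Since $r_1 + r_2 = 1$ and $r_1^2 + r_2^2 = \ell$, the mixed gradient keeps the mean and multiplies the noise covariance by $\ell$; $\ell = 1$ recovers the discrete trajectory. The accompanying rescaling of the Adam hyperparameters is not shown, and the function names are hypothetical.

```python
import math


def svag_coeffs(ell):
    """SVAG mixing coefficients for ell >= 1: r1 + r2 = 1 and r1**2 + r2**2 = ell."""
    if ell < 1:
        raise ValueError("ell must be >= 1")
    s = math.sqrt(2.0 * ell - 1.0)
    return (1.0 + s) / 2.0, (1.0 - s) / 2.0


def svag_gradient(g1, g2, ell):
    """Mix two independent stochastic gradients (lists of floats) so the mean is
    unchanged and the noise covariance is scaled by ell."""
    r1, r2 = svag_coeffs(ell)
    return [r1 * a + r2 * b for a, b in zip(g1, g2)]
```

With `ell=1` the coefficients are exactly `(1.0, 0.0)`, so `svag_gradient(g1, g2, 1)` returns `g1` unchanged.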

Theorems & Definitions (65)

  • Definition 2.1
  • Definition 2.2
  • Definition 2.3
  • Definition 2.4: Order-1 Weak Approximation, li2019stochastic
  • Definition 2.5: Well-behaved NGOS
  • Definition 2.6: Low Skewness Condition
  • Definition 2.7: Bounded Moments Condition
  • Definition 4.1: SDE for RMSprop
  • Theorem 4.2: Informal version of thm:app_rmsprop_sde
  • Remark 4.3
  • ...and 55 more