Adaptive Methods through the Lens of SDEs: Theoretical Insights on the Role of Noise

Enea Monzio Compagnoni, Tianlin Liu, Rustem Islamov, Frank Norbert Proske, Antonio Orvieto, Aurelien Lucchi

TL;DR

This work develops a stochastic-differential-equation framework to theoretically analyze adaptive optimizers in deep learning, deriving the first SDE for SignSGD under general gradient-noise assumptions and revealing three distinct phases that govern its dynamics. It extends the SDE approach to decoupled-weight-decay variants of Adam and RMSprop (AdamW and RMSpropW), deriving novel scaling rules and characterizing their stationary distributions and asymptotic losses. The authors validate the SDE models through Euler–Maruyama integration on networks from MLPs to Transformers, showing that the new SDEs track the actual optimizers more faithfully than prior models, especially near minima. The results illuminate how gradient noise, curvature, and decoupled weight decay interact to stabilize training and inform practical scaling strategies for large-scale models.
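To make "decoupled weight decay" concrete, here is a minimal sketch of one AdamW step and one RMSpropW step. It assumes the common decoupled formulation popularized by Loshchilov and Hutter, in which the decay term $\theta x$ acts on the parameters directly and is not divided by the adaptive denominator. The function names `adamw_step` and `rmspropw_step` are hypothetical. The paper's own algorithm may place $\eta$, $\theta$, and bias correction differently.

```python
import numpy as np


def adamw_step(x, g, m, v, k, eta=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, theta=0.01):
    """One AdamW step at iteration k >= 1 (hypothetical helper, decoupled decay theta)."""
    m = beta1 * m + (1 - beta1) * g       # EMA of gradients
    v = beta2 * v + (1 - beta2) * g**2    # EMA of squared gradients
    m_hat = m / (1 - beta1**k)            # bias corrections
    v_hat = v / (1 - beta2**k)
    # Decoupled decay: theta * x is added outside the adaptive normalization,
    # instead of being folded into g before computing m and v.
    x = x - eta * (m_hat / (np.sqrt(v_hat) + eps) + theta * x)
    return x, m, v


def rmspropw_step(x, g, v, eta=1e-3, beta=0.999, eps=1e-8, theta=0.01):
    """One RMSpropW step (hypothetical helper): RMSprop plus decoupled decay."""
    v = beta * v + (1 - beta) * g**2
    x = x - eta * (g / (np.sqrt(v) + eps) + theta * x)
    return x, v


# At k = 1 with zero-initialized moments, m_hat / sqrt(v_hat) = g / |g|, so the first
# AdamW step (theta = 0) is a SignSGD step of size eta.
x, m, v = adamw_step(np.zeros(2), np.array([2.0, -3.0]), np.zeros(2), np.zeros(2), k=1, theta=0.0)
print(x)
```

The first-step identity in the usage line is one simple way to see why SignSGD is a natural starting point for studying adaptive methods.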

Abstract

Despite the vast empirical evidence supporting the efficacy of adaptive optimization methods in deep learning, their theoretical understanding is far from complete. This work introduces novel SDEs for commonly used adaptive optimizers: SignSGD, RMSprop(W), and Adam(W). These SDEs offer a quantitatively accurate description of these optimizers and help illuminate an intricate relationship between adaptivity, gradient noise, and curvature. Our novel analysis of SignSGD highlights a noteworthy and precise contrast to SGD in terms of convergence speed, stationary distribution, and robustness to heavy-tail noise. We extend this analysis to AdamW and RMSpropW, for which we observe that the role of noise is much more complex. Crucially, we support our theoretical analysis with experimental evidence by verifying our insights: this includes numerically integrating our SDEs using Euler-Maruyama discretization on various neural network architectures such as MLPs, CNNs, ResNets, and Transformers. Our SDEs accurately track the behavior of the respective optimizers, especially when compared to previous SDEs derived for Adam and RMSprop. We believe our approach can provide valuable insights into best training practices and novel scaling rules.
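The abstract refers to numerically integrating the SDEs with the Euler-Maruyama discretization. The sketch below is a generic integrator for $dX_t = b(X_t)\,dt + M(X_t)\,dW_t$. When comparing an SDE with an optimizer, a natural choice is $dt = \eta$, so that one integration step corresponds to one iteration (SDE time $t = k\eta$). The function name and signature are illustrative assumptions, not the authors' code.

```python
import numpy as np


def euler_maruyama(drift, diffusion, x0, dt, n_steps, rng):
    """Integrate dX = drift(X) dt + diffusion(X) dW with the Euler-Maruyama scheme.

    drift(x) returns a length-d vector; diffusion(x) returns a (d, d) matrix M,
    so the noise increment is M @ dW with dW ~ N(0, dt * I).
    Returns the path as an array of shape (n_steps + 1, d).
    """
    x = np.asarray(x0, dtype=float).copy()
    path = [x.copy()]
    for _ in range(n_steps):
        dw = rng.normal(scale=np.sqrt(dt), size=x.shape)  # Brownian increment
        x = x + drift(x) * dt + diffusion(x) @ dw
        path.append(x.copy())
    return np.array(path)


# Usage: Ornstein-Uhlenbeck process dX = -X dt + sqrt(2) dW, whose stationary
# distribution is N(0, 1) in each coordinate.
rng = np.random.default_rng(0)
ou = euler_maruyama(lambda x: -x, lambda x: np.sqrt(2.0) * np.eye(2),
                    np.zeros(2), dt=0.01, n_steps=50_000, rng=rng)
print(ou[5_000:].var(axis=0))  # Monte Carlo estimate, expected to be roughly 1
```

Passing the diffusion as a full matrix keeps the sketch general enough for non-diagonal noise covariances such as $\sqrt{\eta\,\bar{\Sigma}(x)}$.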

Paper Structure

This paper contains 71 sections, 67 theorems, 145 equations, 15 figures, 1 table, 1 algorithm.

Key Result

Theorem 3.2

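Under sufficient regularity conditions, the solution of the following SDE is an order $1$ weak approximation of the discrete update of SignSGD:

$$dX_t = -\left(1 - 2\,\mathbb{P}(\nabla f_{\gamma}(X_t) < 0)\right)dt + \sqrt{\eta}\,\sqrt{\bar{\Sigma}(X_t)}\,dW_t,$$

where $\bar{\Sigma}(x)$ is the noise covariance $\bar{\Sigma}(x) = \mathbb{E}[\xi_{\gamma}(x)\xi_{\gamma}(x)^\top]$, and $\xi_{\gamma}(x):= \mathop{\mathrm{sign}}\nolimits (\nabla f_{\gamma}(x)) - 1 + 2 \mathbb{P}(\nabla f_{\gamma}(x)<0)$ is the noise in the sample $\mathop{\mathrm{sign}}\nolimits(\nabla f_{\gamma}(x))$ (sign and probability are taken componentwise).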
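As a worked special case, assume additive isotropic Gaussian gradient noise, $\nabla f_{\gamma}(x) = \nabla f(x) + Z$ with $Z \sim \mathcal{N}(0, \sigma^2 I_d)$, as in the quadratic setup of Figure 2 ($\Sigma = \sigma^2 I_2$). This is an illustrative assumption, not the general noise model of the theorem; $\Phi$ denotes the standard normal CDF. The drift term $1 - 2\,\mathbb{P}(\cdot < 0)$ and the covariance of $\xi_{\gamma}$ then have closed forms:

```latex
% Assumption: \nabla f_\gamma(x) = \nabla f(x) + Z, \; Z \sim \mathcal{N}(0, \sigma^2 I_d)
\mathbb{P}\big(\partial_i f_\gamma(x) < 0\big) = \Phi\!\left(-\tfrac{\partial_i f(x)}{\sigma}\right)
\;\Longrightarrow\;
1 - 2\,\mathbb{P}\big(\partial_i f_\gamma(x) < 0\big)
  = 2\Phi\!\left(\tfrac{\partial_i f(x)}{\sigma}\right) - 1
  = \operatorname{erf}\!\left(\tfrac{\partial_i f(x)}{\sqrt{2}\,\sigma}\right)

% sign^2 = 1 a.s. and the coordinates of Z are independent, hence
\bar{\Sigma}(x) = \operatorname{diag}\!\left(1 - \operatorname{erf}^2\!\left(\tfrac{\partial_i f(x)}{\sqrt{2}\,\sigma}\right)\right)_{i=1}^{d}

% Two limiting regimes per coordinate:
|\partial_i f| \gg \sigma:\quad \text{drift} \to -\operatorname{sign}(\partial_i f), \quad \bar{\Sigma}_{ii} \to 0
|\partial_i f| \ll \sigma:\quad \text{drift} \approx -\sqrt{2/\pi}\,\partial_i f/\sigma, \quad \bar{\Sigma}_{ii} \approx 1
```

The large-gradient regime is a deterministic sign flow, while the small-gradient regime behaves like SGD with a gradient rescaled by $\sqrt{2/\pi}/\sigma$ and unit-scale noise. These regimes are consistent with the ODE and SDE phases referred to in Figure 2.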

Figures (15)

  • Figure 1: Comparison of SignSGD and its SDE in terms of $f(x)$: Our SDE successfully tracks the dynamics of SignSGD on several architectures, datasets, and hyperparameters: DNN on the Breast Cancer dataset (Top-Left); CNN on MNIST (Top-Right); Transformer on MNIST (Bottom-Left); ResNet on CIFAR-10 (Bottom-Right).
  • Figure 2: Phases of SignSGD: The ODE of Phase 1 and the SDE of Phase 3 overlap with the "Full" SDE, as predicted by the lemma on the three phases. In Phase 2, the dynamics satisfy the prescribed bounds (Top-Left); Phases of the Loss: The bounds on the loss derived for each phase in the SignSGD loss-dynamics lemma correctly track the loss evolution (Top-Right); The dynamics of the moments of $X_t$ predicted by the SignSGD stationary-distribution lemma track the empirical ones (Bottom-Left); If the schedulers satisfy the condition in the schedulers lemma, the loss decays to $0$ as prescribed; otherwise, it does not converge to $0$ (Bottom-Right). For each figure, $f(x) = \frac{x^{\top}H x}{2}$ with $H = \mathop{\mathrm{diag}}\nolimits(1,2)$, $\eta = 0.001$, and $\Sigma = \sigma^2 I_2$ where $\sigma = 0.1$.
  • Figure 3: The two images at the top compare the SDEs of AdamW and RMSpropW with the respective optimizers in terms of trajectories and $f(x)$ for a convex quadratic function, while the other two provide the same comparison for an embedded saddle. In all cases, we observe good agreement.
  • Figure 4: The two images at the top compare AdamW and its SDE in terms of $f(x)$; the two at the bottom do the same for RMSpropW. In both cases, the first is a Transformer on MNIST and the second a ResNet on CIFAR-10: Our SDEs match the respective optimizers.
  • Figure 5: The loss predicted by the AdamW scaling-rule lemma matches the experimental results on a convex quadratic function. AdamW is run with regularization parameter $\theta = 1$. AdamW R (AdamW Rescaled) applies the scaling rule with $\kappa=2$ to all hyperparameters. AdamW NR (AdamW Not Rescaled) applies the scaling rule with $\kappa=2$ to all hyperparameters except $\theta$, which is left unchanged. Our scaling rule holds, and failing to rescale $\theta$ prevents the optimizer from preserving the asymptotic loss level. The same happens for $\theta=4$ (Top-Left); The same holds for RMSpropW (Top-Right); For AdamW, $\beta_1$ and $\beta_2$ influence which basin attracts the dynamics and how fast they converge, but not the asymptotic loss level inside the basin (Bottom-Left); For both AdamW and RMSpropW, the variance at convergence predicted by the AdamW stationary-distribution lemma matches the experimental results (Bottom-Right).
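The Figure 2 setup can be reproduced in spirit with a short simulation comparing SignSGD with its SDE on $f(x) = \frac{x^{\top}Hx}{2}$, $H = \mathop{\mathrm{diag}}(1,2)$, $\eta = 0.001$, $\sigma = 0.1$. This is a sketch under an assumed noise model: the noisy gradient is $\nabla f(x) + \sigma Z$ with $Z \sim \mathcal{N}(0, I_2)$. Under this model, $1 - 2\,\mathbb{P}(\partial_i f_{\gamma} < 0) = \operatorname{erf}(\partial_i f/(\sqrt{2}\sigma))$ and $\bar{\Sigma}$ is diagonal with entries $1 - \operatorname{erf}^2(\cdot)$. The starting point, run counts, and helper names are hypothetical choices, not the paper's experimental code.

```python
import math

import numpy as np

# Setting from the Figure 2 caption; the Gaussian noise model is our assumption.
H = np.array([1.0, 2.0])  # diagonal of H
ETA, SIGMA = 1e-3, 0.1
erf = np.vectorize(math.erf)  # elementwise erf without a SciPy dependency


def loss(x):
    """f(x) = x^T H x / 2 for a batch of points with shape (n_runs, 2)."""
    return 0.5 * np.sum(H * x**2, axis=-1)


def simulate(n_runs=200, n_steps=3000, x0=(0.5, 0.5), seed=0):
    """Mean loss per step of SignSGD and of its SDE (Euler-Maruyama with dt = eta)."""
    rng = np.random.default_rng(seed)
    x = np.tile(np.asarray(x0, dtype=float), (n_runs, 1))  # SignSGD iterates
    X = x.copy()  # SDE states, same initialization
    loss_opt, loss_sde = [], []
    for _ in range(n_steps):
        # SignSGD: move each coordinate by eta against the sign of a noisy gradient.
        noisy_grad = H * x + SIGMA * rng.standard_normal(x.shape)
        x = x - ETA * np.sign(noisy_grad)
        # SDE: drift -erf(grad / (sqrt(2) sigma)), diffusion sqrt(eta) * sqrt(1 - erf^2).
        e = erf(H * X / (math.sqrt(2.0) * SIGMA))
        dW = math.sqrt(ETA) * rng.standard_normal(X.shape)  # Brownian increment, dt = eta
        X = X - e * ETA + math.sqrt(ETA) * np.sqrt(1.0 - e**2) * dW
        loss_opt.append(loss(x).mean())
        loss_sde.append(loss(X).mean())
    return np.array(loss_opt), np.array(loss_sde)


opt, sde = simulate()
# Time-averaged loss over the last 1000 steps (near-stationary regime).
print(f"SignSGD: {opt[-1000:].mean():.2e}   SDE: {sde[-1000:].mean():.2e}")
```

Linearizing erf near the minimizer ($\operatorname{erf}(z) \approx 2z/\sqrt{\pi}$) turns each coordinate into an Ornstein-Uhlenbeck process with stationary expected loss $\eta\sigma\sqrt{\pi}/(4\sqrt{2})$, independent of $h_i$. Here that totals about $6.3 \times 10^{-5}$, which both printed averages should approach up to Monte Carlo error.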

Theorems & Definitions (122)

  • Definition 3.1: Weak Approximation
  • Theorem 3.2: Informal Statement of the SignSGD SDE Theorem
  • Proof idea
  • Corollary 3.3: Informal Statement of the Simplified SignSGD SDE Corollary
  • Lemma 3.4
  • Lemma 3.5
  • Proof idea
  • Lemma 3.6
  • Lemma 3.7
  • Proof idea
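For reference, here is the notion behind "order $1$ weak approximation" in Theorem 3.2, in the standard form used in this line of work (e.g., Li, Tai and E). The exact test-function class $G$ in the paper's Definition 3.1 may be stated differently. Let $0 < \eta < 1$, $T > 0$, $N = \lfloor T/\eta \rfloor$. Let $G$ be a class of smooth test functions that, together with their derivatives, grow at most polynomially.

```latex
\{X_t\} \text{ is an order-}\alpha\text{ weak approximation of } \{x_k\} \iff
\forall g \in G \;\; \exists\, C > 0 \text{ independent of } \eta :\quad
\max_{k = 0,\dots,N} \big|\, \mathbb{E}\, g(x_k) - \mathbb{E}\, g(X_{k\eta}) \,\big| \le C\, \eta^{\alpha}
```

With $\alpha = 1$, expectations of smooth observables such as the loss differ between the optimizer and the SDE by $O(\eta)$, uniformly over $k \le N$. The time rescaling $t = k\eta$ is also why the diffusion term carries a factor $\sqrt{\eta}$.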