Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients

Lukas Balles, Philipp Hennig

TL;DR

Adam can be viewed as a two-part update: the direction is set by the sign of the stochastic gradient, and the magnitude is scaled by an estimate of its relative variance. The authors formalize the two parts as Stochastic Sign Descent (SSD) and Stochastic Variance-Adapted Gradient (SVAG), and derive M-SVAG, a momentum variant of SVAG. Theoretically, sign-based updates can be advantageous on noisy, ill-conditioned problems. Empirically, the sign component largely drives Adam's behavior, including its generalization drawbacks, whereas variance adaptation generally helps and can mitigate those drawbacks. Variance adaptation is presented as a general technique applicable beyond Adam, with M-SVAG offering a practical alternative when sign-based methods fail. Across multiple tasks, the results suggest that the usefulness of the sign is problem-dependent, while variance adaptation tends to improve convergence and generalization.

Abstract

The ADAM optimizer is exceedingly popular in the deep learning community. Often it works very well, sometimes it doesn't. Why? We interpret ADAM as a combination of two aspects: for each weight, the update direction is determined by the sign of stochastic gradients, whereas the update magnitude is determined by an estimate of their relative variance. We disentangle these two aspects and analyze them in isolation, gaining insight into the mechanisms underlying ADAM. This analysis also extends recent results on adverse effects of ADAM on generalization, isolating the sign aspect as the problematic one. Transferring the variance adaptation to SGD gives rise to a novel method, completing the practitioner's toolbox for problems where ADAM fails.
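
The sign/magnitude reading of Adam can be checked numerically. The sketch below is illustrative only: it ignores Adam's epsilon and bias correction, and it assumes the relative variance is estimated as `(v - m^2) / m^2` from the moving averages `m` and `v`. The function names and the example values are hypothetical and do not come from the paper.

```python
import math

def adam_style_update(m, v):
    """Per-coordinate Adam-style step m / sqrt(v), without epsilon or bias correction."""
    return m / math.sqrt(v)

def sign_times_variance_factor(m, v):
    """The same step written as sign(m) * (1 + eta^2)^(-1/2).

    eta^2 = (v - m^2) / m^2 is an assumed estimate of the relative variance;
    it requires m != 0.
    """
    eta_sq = (v - m * m) / (m * m)
    return math.copysign(1.0, m) / math.sqrt(1.0 + eta_sq)

# Hypothetical (m, v) pairs for three coordinates, chosen so that v >= m^2.
examples = [(0.5, 0.26), (-0.1, 1.0), (2.0, 4.0)]

decomposition = [
    (adam_style_update(m, v), sign_times_variance_factor(m, v))
    for m, v in examples
]

for direct, factored in decomposition:
    print(f"direct: {direct:+.6f}   sign * factor: {factored:+.6f}")
```

Both columns agree because `m / sqrt(v) = sign(m) * |m| / sqrt(v)` and `v / m^2 = 1 + (v - m^2) / m^2`. A coordinate with large relative variance (the second pair) takes a much smaller step than one with none (the third pair).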

Paper Structure

This paper contains 48 sections, 6 theorems, 61 equations, 8 figures, 5 algorithms.

Key Result

Lemma 1

Let $\hat{p}\in\mathbb{R}^d$ be a random variable with $\mathbf{E}[\hat{p}]=p$ and $\mathbf{var}[\hat{p}_i]=\sigma_i^2$. Then $\mathbf{E}[\Vert \gamma\odot\hat{p} - p\Vert_2^2]$ is minimized by $\gamma_i = \frac{p_i^2}{p_i^2+\sigma_i^2} = \frac{1}{1+\sigma_i^2/p_i^2}$, and $\mathbf{E}[\Vert \gamma\odot\operatorname{sign}(\hat{p}) - \operatorname{sign}(p) \Vert_2^2]$ is minimized by $\gamma_i = 2\rho_i - 1$, where $\rho_i:=\mathbf{P}[\operatorname{sign}(\hat{p}_i) = \operatorname{sign}(p_i)]$ is the probability that $\hat{p}_i$ has the same sign as $p_i$.
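
A small numerical check can make the lemma concrete. The sketch below assumes the minimizers are $\gamma_i = p_i^2/(p_i^2+\sigma_i^2)$ for the gradient and $\gamma_i = 2\rho_i - 1$ for its sign, which is what setting the derivative of each quadratic in $\gamma_i$ to zero gives. The three-point distribution for a single coordinate is hypothetical and chosen only so that the expectations can be computed exactly.

```python
def expectation(values, probs, f):
    """Exact expectation of f(X) for a discrete random variable X."""
    return sum(pr * f(x) for x, pr in zip(values, probs))

def sign(x):
    return (x > 0) - (x < 0)

# Hypothetical distribution of one coordinate of the stochastic estimate p_hat.
values = [-1.0, 1.0, 3.0]
probs = [0.2, 0.3, 0.5]

p = expectation(values, probs, lambda x: x)               # true value E[p_hat]
var = expectation(values, probs, lambda x: (x - p) ** 2)  # sigma^2

# Candidate minimizers (assumed forms, see the paragraph above).
gamma_grad = p ** 2 / (p ** 2 + var)
rho = expectation(values, probs, lambda x: 1.0 if sign(x) == sign(p) else 0.0)
gamma_sign = 2.0 * rho - 1.0

def grad_error(gamma):
    """E[(gamma * p_hat - p)^2]"""
    return expectation(values, probs, lambda x: (gamma * x - p) ** 2)

def sign_error(gamma):
    """E[(gamma * sign(p_hat) - sign(p))^2]"""
    return expectation(values, probs, lambda x: (gamma * sign(x) - sign(p)) ** 2)

# Perturbing either factor in any direction should not reduce its error.
for delta in (-0.05, 0.05):
    print(grad_error(gamma_grad) <= grad_error(gamma_grad + delta),
          sign_error(gamma_sign) <= sign_error(gamma_sign + delta))
```

Both errors are convex quadratics in the scaling factor, so comparing against nearby values is enough to confirm the minimum in this example.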

Figures (8)

  • Figure 1: The methods under consideration in this paper. "M-" refers to the use of $m_t$ in place of $g_t$, which we colloquially refer to as the momentum variant. M-SVAG will be derived below.
  • Figure 2: Performance of SGD and SSD on stochastic quadratic problems. Rows correspond to different QPs: the eigenspectrum is shown, and each QP is used with both a randomly rotated and an axis-aligned eigenbasis. Columns correspond to different noise levels. The individual panels show function value over number of steps. On the well-conditioned problem, gradient descent vastly outperforms the sign-based method in the noise-free case, but the difference evens out when noise is added. The orientation of the eigenbasis has little effect on the comparison in the well-conditioned case. On the ill-conditioned problem, the methods perform roughly equally when the eigenbasis is randomly rotated. SSD benefits drastically from an axis-aligned eigenbasis, where it clearly outperforms SGD.
  • Figure 3: Variance adaptation factors as functions of the relative standard deviation $\eta$. The optimal factor for the sign of a (Gaussian) stochastic gradient is $\operatorname{erf}[(\sqrt{2}\eta)^{-1}]$, which is closely approximated by $(1+\eta^2)^{-1/2}$, the factor implicitly employed by Adam. $(1+\eta^2)^{-1}$ is the optimal factor for a stochastic gradient.
  • Figure 4: Conceptual sketch of variance-adapted stochastic gradients. The left panel shows the true gradient $\nabla\mathcal{L}=(2,1)$ and stochastic gradients scattered around it with $(\sigma_1, \sigma_2)=(1, 1.5)$. In the right panel, we scale the $i$-th coordinate by $(1+\eta_i^2)^{-1}$. In this example, the $\theta_2$-coordinate has much higher relative variance ($\eta_2^2 = 2.25$) than the $\theta_1$-coordinate ($\eta^2_1 = 0.25$) and is thus shortened. This reduces the variance of the update direction at the expense of biasing it away from the true gradient in expectation.
  • Figure 5: Experimental results on the four test problems. Plots display training loss and test accuracy over the number of steps. Curves for the different optimization methods are color-coded. The shaded area spans one standard deviation, obtained from ten replications with different random seeds.
  • ...and 3 more figures
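
The factors named in the Figure 3 caption and the example in the Figure 4 caption can be reproduced with a few lines of code. This is a minimal sketch, not the authors' code: the function names are hypothetical, and it assumes the relative standard deviation of coordinate $i$ is $\eta_i = \sigma_i / |\nabla\mathcal{L}_i|$, which matches the $\eta_i^2$ values quoted in the caption.

```python
import math

def optimal_sign_factor(eta):
    """erf[(sqrt(2) * eta)^-1]: optimal factor for the sign of a Gaussian stochastic gradient."""
    return math.erf(1.0 / (math.sqrt(2.0) * eta))

def adam_factor(eta):
    """(1 + eta^2)^(-1/2): the factor implicitly employed by Adam."""
    return (1.0 + eta ** 2) ** -0.5

def gradient_factor(eta):
    """(1 + eta^2)^(-1): optimal factor for the stochastic gradient itself."""
    return 1.0 / (1.0 + eta ** 2)

# Compare the three factors on a few hypothetical values of eta.
etas = (0.25, 0.5, 1.0, 2.0, 4.0)
for eta in etas:
    print(f"eta={eta:4.2f}  erf={optimal_sign_factor(eta):.4f}  "
          f"adam={adam_factor(eta):.4f}  grad={gradient_factor(eta):.4f}")

# Figure 4 example: true gradient (2, 1), noise standard deviations (1, 1.5).
grad = (2.0, 1.0)
sigma = (1.0, 1.5)
eta_sq = [(s / abs(g)) ** 2 for g, s in zip(grad, sigma)]
scaled = [g / (1.0 + e) for g, e in zip(grad, eta_sq)]
print("eta^2:", eta_sq)
print("variance-adapted gradient:", scaled)
```

The second coordinate, which has the higher relative variance, is shrunk far more than the first, so the expected update direction tilts toward the low-noise coordinate.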

Theorems & Definitions (11)

  • Lemma 1
  • Theorem 1
  • Lemma 2: Lemma 3.1 in Wilson2017
  • Lemma 3
  • Lemma 4
  • proof
  • proof: Proof of Lemma 1
  • Lemma 5
  • proof
  • proof: Proof of Theorem 1
  • ...and 1 more