Noise Is Not the Main Factor Behind the Gap Between SGD and Adam on Transformers, but Sign Descent Might Be

Frederik Kunstner, Jacques Chen, Jonathan Wilder Lavington, Mark Schmidt

TL;DR

Problem: Why does Adam outperform SGD on transformer models, and is stochastic noise responsible? Approach: systematically vary batch size from small to full, comparing SGD and Adam with and without momentum, and testing sign-descent-like and normalized-update variants. Findings: removing stochasticity does not close the gap; Adam's advantage grows with batch size, and sign descent with momentum closes much of the gap in full batch; normalization helps SGD but does not fully match Adam. Significance: reveals that deterministic optimization dynamics largely shape the performance gap and highlights a potential path to simpler, sign-based methods to analyze and improve training of transformers.

Abstract

The success of the Adam optimizer on a wide array of architectures has made it the default in settings where stochastic gradient descent (SGD) performs poorly. However, our theoretical understanding of this discrepancy is lagging, preventing the development of significant improvements on either algorithm. Recent work advances the hypothesis that Adam and other heuristics like gradient clipping outperform SGD on language tasks because the distribution of the error induced by sampling has heavy tails. This suggests that Adam outperforms SGD because it uses a more robust gradient estimate. We evaluate this hypothesis by varying the batch size, up to the entire dataset, to control for stochasticity. We present evidence that stochasticity and heavy-tailed noise are not major factors in the performance gap between SGD and Adam. Rather, Adam performs better as the batch size increases, while SGD is less effective at taking advantage of the reduction in noise. This raises the question as to why Adam outperforms SGD in the full-batch setting. Through further investigation of simpler variants of SGD, we find that the behavior of Adam with large batches is similar to sign descent with momentum.
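To make the link between Adam and sign descent concrete, here is a minimal sketch in plain Python. It relies on one well-known fact: with both moving averages switched off (`beta1 = beta2 = 0`) and `eps = 0`, the Adam step `g / sqrt(g^2)` is exactly `sign(g)`. The function names, the toy values, and in particular the form of the momentum buffer in `sign_momentum_step` are assumptions made for illustration; this excerpt does not spell out the paper's exact update rule.

```python
import math


def sign(x):
    """Return -1, 0, or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def adam_step(params, grads, m, v, t, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step with bias correction on plain Python lists (t starts at 1).

    Note: with eps=0 a zero gradient coordinate would divide by zero.
    """
    new_params, new_m, new_v = [], [], []
    for p, g, mi, vi in zip(params, grads, m, v):
        mi = beta1 * mi + (1 - beta1) * g          # first-moment average
        vi = beta2 * vi + (1 - beta2) * g * g      # second-moment average
        m_hat = mi / (1 - beta1 ** t)              # bias correction
        v_hat = vi / (1 - beta2 ** t)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_params, new_m, new_v


def sign_momentum_step(params, grads, buf, lr=0.1, beta=0.9):
    """One sign-descent step with a momentum buffer.

    Assumed form (hypothetical): buf <- beta * buf + sign(g), x <- x - lr * buf.
    """
    new_buf = [beta * b + sign(g) for b, g in zip(buf, grads)]
    new_params = [p - lr * b for p, b in zip(params, new_buf)]
    return new_params, new_buf


# Toy values chosen only for illustration.
params = [1.0, -2.0]
grads = [0.3, -4.0]

adam_p, _, _ = adam_step(params, grads, [0.0, 0.0], [0.0, 0.0], t=1,
                         lr=0.1, beta1=0.0, beta2=0.0, eps=0.0)
sign_p, _ = sign_momentum_step(params, grads, [0.0, 0.0], lr=0.1, beta=0.0)
print(adam_p, sign_p)
```

In this degenerate setting both updates move every coordinate by the same step size regardless of gradient magnitude, which is the behavior that distinguishes sign-based methods from plain SGD.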

Paper Structure (33 sections, 6 equations, 22 figures, 2 tables, 2 algorithms)

Figures (22)

  • Figure 1: The Heavy-Tail hypothesis: the gap between SGD and Adam is caused by a heavier tail in the distribution of the stochastic gradient error. The performance gap between SGD and Adam is larger and more consistent on transformers on text data (right: PTB, WikiText-2, SQuAD) than on CNNs on image data (left: MNIST, CIFAR-10), which coincides with a heavier tail in the distribution of the stochastic gradient error. Zhang et al. (2020) hypothesize that heavier tails might be the cause of this gap. Top: Distribution of errors in stochastic gradients at initialization ($\left\Vert g-\tilde{g}\right\Vert$, where $\tilde{g}$ is a stochastic gradient and $g$ is the full gradient) compared against a Gaussian (QQ-plot). Bottom: SGD and Adam with and without momentum ($+$m/$-$m) with small batch sizes.
  • Figure 2: The gap does not disappear when training in full batch. Repeating the training procedure of Figure 1 in full batch reveals a similar or larger gap between SGD and Adam.
  • Figure 3: The gap between SGD and Adam increases with batch size. Performance after a similar number of iterations across batch sizes. The gap between Adam and SGD grows with batch size on language models, confirming the trend observed in Figures 1 and 2. Due to practical implementation issues, smaller batch sizes still run for more iterations on the larger datasets (WikiText-2, SQuAD) despite being stopped after one epoch (see the appendix on experimental details). The degradation in performance as the batch size increases is explained by this decrease in the number of iterations. We still observe that the gap grows with batch size despite this bias favoring small batches. To show the effect of batch size beyond the first epoch, we show the full runs in Figure 4.
  • Figure 4: Adam better takes advantage of the reduced noise with larger batch sizes. For each problem, each optimizer is shown with five different batch sizes, interpolating between the small batch setting of Figure 1 and the full batch setting of Figure 2. Larger batch sizes run for fewer iterations and terminate earlier, indicated by the markers ($\blacktriangledown$), with smaller markers for smaller batches. SGD follows a similar trajectory across most batch sizes, indicating that increasing the batch size past some threshold does not improve the performance; similar results would be expected from running with the same batch size but terminating earlier. Adam, on the other hand, achieves similar error in fewer iterations when the batch size is increased. Top/bottom: results without/with momentum.
  • Figure 5: Sign descent and normalized GD improve with batch size. Performance after a similar number of iterations across batch sizes. Sign descent performs poorly in small batch sizes and is outperformed by normalized GD. In full batch, sign descent outperforms normalized GD and approaches the performance of Adam. We show the full runs for each batch size in the figure labeled fig:interpolation-normalized.
  • ...and 17 more figures
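
The quantity plotted in the top row of Figure 1, the stochastic gradient error $\Vert g-\tilde{g}\Vert$, can be illustrated with a small sketch. The code below is not the paper's experimental setup: it assumes a toy two-dimensional least-squares problem with synthetic Gaussian data, and every name in it (`per_example_grad`, `mean_grad`, `error_norms`, the dataset size of 64) is hypothetical. It only shows how the error between a mini-batch gradient and the full gradient is measured, and that the error vanishes when the batch is the entire dataset.

```python
import math
import random


def per_example_grad(w, x, y):
    """Gradient of 0.5 * (w.x - y)^2 with respect to w."""
    residual = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [residual * xi for xi in x]


def mean_grad(w, examples):
    """Average per-example gradient over a list of (x, y) pairs."""
    total = [0.0] * len(w)
    for x, y in examples:
        for j, gj in enumerate(per_example_grad(w, x, y)):
            total[j] += gj
    return [t / len(examples) for t in total]


def grad_error_norm(w, data, batch):
    """Euclidean norm of (full gradient g) minus (mini-batch gradient g_tilde)."""
    g = mean_grad(w, data)
    g_tilde = mean_grad(w, batch)
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(g, g_tilde)))


def error_norms(w, data, batch_size, n_batches, rng):
    """Error norms for n_batches mini-batches sampled without replacement."""
    return [grad_error_norm(w, data, rng.sample(data, batch_size))
            for _ in range(n_batches)]


# Synthetic toy data (an assumption, not the datasets used in the paper).
rng = random.Random(0)
data = [([rng.gauss(0, 1), rng.gauss(0, 1)], rng.gauss(0, 1)) for _ in range(64)]
w0 = [0.0, 0.0]  # measure the error at initialization

for batch_size in (1, 8, 64):
    errs = error_norms(w0, data, batch_size, 200, rng)
    print(batch_size, sum(errs) / len(errs))
```

On a real model, the list returned by `error_norms` is what would be compared against a Gaussian in a QQ-plot; growing `batch_size` up to the dataset size is the knob the experiments use to remove stochasticity.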