Improving Generalization Performance by Switching from Adam to SGD

Nitish Shirish Keskar, Richard Socher

TL;DR

This work tackles the generalization disparity between adaptive optimizers like Adam and SGD by introducing SWATS, a hybrid method that starts with Adam and automatically switches to SGD based on a gradient-subspace projection criterion. The switch employs a closed-form SGD learning rate derived from projecting the Adam step onto the gradient and stabilizes this estimate via an exponential average. Across diverse benchmarks (CIFAR-10/100, Tiny-ImageNet, PTB, WT2), SWATS often matches or surpasses the best of SGD or Adam without adding hyperparameters, effectively closing the generalization gap in many tasks. The findings highlight the potential of hybrid optimization strategies to balance rapid initial progress with robust generalization in deep learning.
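
The closed-form learning rate can be written out as a short derivation. One reading consistent with Figure 3's notation (iterate $w_k$, stochastic gradient $g_k$, Adam step $p_k$) is that $\gamma_k$ is chosen so that the orthogonal projection of the SGD step $-\gamma_k g_k$ onto $p_k$ reproduces $p_k$ exactly. The averaging constant $\beta_2$ and the symbols $\lambda_k$, $\hat\lambda_k$ for the running average and its bias-corrected value are assumptions of this sketch.

```latex
\operatorname{proj}_{p_k}(-\gamma_k g_k)
  = \frac{(-\gamma_k g_k)^{\top} p_k}{p_k^{\top} p_k}\, p_k = p_k
  \;\Longrightarrow\;
  \gamma_k = \frac{p_k^{\top} p_k}{-\,p_k^{\top} g_k}
  \qquad (p_k^{\top} g_k \neq 0)

\lambda_k = \beta_2 \lambda_{k-1} + (1 - \beta_2)\,\gamma_k,
  \qquad
  \hat\lambda_k = \frac{\lambda_k}{1 - \beta_2^{k}}
```

With $\lambda_0 = 0$ and a constant estimate $\gamma_k = c$, the average is $\lambda_k = c(1 - \beta_2^k)$, so $\hat\lambda_k = c$ exactly; the bias correction removes the pull toward the zero initialization, as in Adam's own moment estimates.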

Abstract

Despite superior training outcomes, adaptive optimization methods such as Adam, Adagrad or RMSprop have been found to generalize poorly compared to stochastic gradient descent (SGD). These methods tend to perform well in the initial portion of training but are outperformed by SGD at later stages of training. We investigate a hybrid strategy that begins training with an adaptive method and switches to SGD when appropriate. Concretely, we propose SWATS, a simple strategy which switches from Adam to SGD when a triggering condition is satisfied. The condition we propose relates to the projection of Adam steps on the gradient subspace. By design, the monitoring process for this condition adds very little overhead and does not increase the number of hyperparameters in the optimizer. We report experiments on several standard benchmarks: ResNet, SENet, DenseNet and PyramidNet on the CIFAR-10 and CIFAR-100 data sets, ResNet on the Tiny-ImageNet data set, and language modeling with recurrent networks on the PTB and WT2 data sets. The results show that our strategy is capable of closing the generalization gap between SGD and Adam on a majority of the tasks.
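
The monitoring step can be sketched in Python. This is a minimal sketch under stated assumptions, not the authors' implementation: it assumes the per-step estimate $\gamma_k = p_k^\top p_k / (-p_k^\top g_k)$, an exponential average that reuses Adam's $\beta_2$ with bias correction, and a switch once the corrected average agrees with $\gamma_k$ to within a small tolerance. The class name SwitchMonitor, the helper dot, the tolerance 1e-9 (treated as a fixed constant, not a tuned hyperparameter) and the plain-list vectors are all hypothetical choices. The per-iteration cost is two inner products, which is in line with the low-overhead claim above.

```python
def dot(a, b):
    """Inner product of two equal-length sequences of floats."""
    return sum(x * y for x, y in zip(a, b))


class SwitchMonitor:
    """Hypothetical helper (not the paper's code): decides when to hand
    training over from Adam to SGD using the projected learning rate."""

    def __init__(self, beta2=0.999, eps=1e-9):
        self.beta2 = beta2      # averaging constant (assumed: Adam's beta2)
        self.eps = eps          # switching tolerance (illustrative value)
        self.lam = 0.0          # exponential average of gamma_k
        self.k = 0              # number of gamma_k values averaged so far
        self.switched = False
        self.sgd_lr = None      # learning rate handed to SGD at the switch

    def update(self, adam_step, grad):
        """adam_step is p_k (the update Adam applies), grad is g_k.
        Returns True once training should continue with SGD."""
        if self.switched:
            return True
        pg = dot(adam_step, grad)
        if pg == 0.0:
            return False        # projection undefined; stay with Adam
        gamma = dot(adam_step, adam_step) / (-pg)
        self.k += 1
        self.lam = self.beta2 * self.lam + (1.0 - self.beta2) * gamma
        lam_hat = self.lam / (1.0 - self.beta2 ** self.k)  # bias correction
        # At k == 1 the corrected average equals gamma by construction,
        # so the comparison is only meaningful from the second estimate on.
        if self.k > 1 and abs(lam_hat - gamma) < self.eps:
            self.switched = True
            self.sgd_lr = lam_hat
        return self.switched


monitor = SwitchMonitor()
grad = [1.0, 2.0]
adam_step = [-0.5, -1.0]   # p_k = -0.5 * g_k, so gamma_k = 0.5 every step
print(monitor.update(adam_step, grad))  # False: first estimate only
print(monitor.update(adam_step, grad))  # True: average agrees with gamma_k
print(monitor.sgd_lr)
```

After the switch, training would continue with SGD using monitor.sgd_lr as its learning rate. In a real training loop, adam_step and grad would be the flattened Adam update and stochastic gradient at the current iteration.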

Paper Structure

This paper contains 6 sections, 14 equations, 7 figures, 1 table, 1 algorithm.

Figures (7)

  • Figure 1: Training the DenseNet architecture on the CIFAR-10 data set with four optimizers: SGD, Adam, Adam-Clip$(1,\infty)$ and Adam-Clip$(0, 1)$. SGD achieves the best testing accuracy while training with Adam leads to a generalization gap of roughly $2\%$. Setting a minimum learning rate for each parameter of Adam partially closes the generalization gap.
  • Figure 2: Training the DenseNet architecture on the CIFAR-10 data set using Adam and switching to SGD (learning rate $0.1$, momentum $0.9$) after 10, 40 or 80 epochs; the switchover point is denoted by Sw@ in the figure. Switching early enables the model to achieve testing accuracy comparable to SGD, but switching too late in the training process leads to a generalization gap similar to that of Adam.
  • Figure 3: Illustrating the learning rate for SGD ($\gamma_k$) estimated by our proposed projection given an iterate $w_k$, a stochastic gradient $g_k$ and the Adam step $p_k$.
  • Figure 4: Numerical experiments comparing SGD(M), Adam and SWATS with tuned learning rates on the ResNet-32, DenseNet, PyramidNet and SENet architectures on CIFAR-10 and CIFAR-100 data sets.
  • Figure 5: Numerical experiments comparing SGD(M), Adam and SWATS with tuned learning rates on the ResNet-18 architecture on the Tiny-ImageNet data set.
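
The Adam-Clip variants in Figure 1 can be illustrated with a small sketch. We assume here, from the notation rather than from anything the caption states, that Adam-Clip$(p, q)$ clips each parameter's effective Adam learning rate to the interval $[p \cdot \alpha_{sgd}, q \cdot \alpha_{sgd}]$ for a reference SGD learning rate $\alpha_{sgd}$. Under that reading, Adam-Clip$(1,\infty)$ imposes only a minimum learning rate per parameter, matching the caption, while Adam-Clip$(0, 1)$ imposes only a maximum. The function name adam_clip_step and every default value below are hypothetical.

```python
import math


def adam_clip_step(w, g, m, v, t, lr=1e-3, sgd_lr=0.1, p=1.0, q=math.inf,
                   beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update in which each coordinate's effective learning rate
    is clipped to [p * sgd_lr, q * sgd_lr] (hypothetical Adam-Clip(p, q)).
    w, g, m, v are equal-length lists; t is the 1-based step count."""
    new_w, new_m, new_v = [], [], []
    for wi, gi, mi, vi in zip(w, g, m, v):
        mi = beta1 * mi + (1 - beta1) * gi          # first-moment estimate
        vi = beta2 * vi + (1 - beta2) * gi * gi     # second-moment estimate
        m_hat = mi / (1 - beta1 ** t)               # bias corrections
        v_hat = vi / (1 - beta2 ** t)
        rate = lr / (math.sqrt(v_hat) + eps)        # Adam's per-parameter rate
        rate = min(max(rate, p * sgd_lr), q * sgd_lr)  # clip to the interval
        new_w.append(wi - rate * m_hat)
        new_m.append(mi)
        new_v.append(vi)
    return new_w, new_m, new_v


# Adam-Clip(1, inf): the tiny Adam rate is floored at sgd_lr.
w_clip, _, _ = adam_clip_step([0.0], [10.0], [0.0], [0.0], t=1)
# p=0, q=inf leaves Adam's own rate untouched.
w_adam, _, _ = adam_clip_step([0.0], [10.0], [0.0], [0.0], t=1, p=0.0)
print(w_clip, w_adam)
```

With a large gradient, plain Adam takes a first step of roughly lr in magnitude, while the clipped variant takes an SGD-sized step, which illustrates how a per-parameter minimum learning rate pulls Adam's behavior toward SGD.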