Promoting Exploration in Memory-Augmented Adam using Critical Momenta

Pranshu Malviya, Gonçalo Mordido, Aristide Baratin, Reza Babanezhad Harikandeh, Jerry Huang, Simon Lacoste-Julien, Razvan Pascanu, Sarath Chandar

TL;DR

This work addresses the generalization gap of adaptive optimizers by promoting exploration of the loss landscape. It introduces Adam+CM, a memory-augmented variant of Adam that stores a buffer of critical momenta and aggregates them with the current momentum. This encourages the optimizer to overshoot sharp minima and settle in flatter ones. Theoretical insights are provided under simplified quadratic-loss assumptions. Extensive experiments on language modeling, image classification, and online learning show that Adam+CM often outperforms Adam, critical-gradient (CG) variants such as Adam+CG, and SAM in terms of flatness and generalization. The approach offers practical performance benefits and can be integrated with existing optimizers to enhance exploration without excessive memory overhead.
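
The mechanism summarized above can be sketched in plain Python. This is a minimal, hypothetical illustration, not the authors' implementation; their code is in the CMOptimizer repository. The sketch makes three assumptions. First, a momentum's buffer priority is its L2 norm. Second, every stored priority is multiplied by a decay factor `decay` (λ) after each step, so stale entries are eventually evicted. Third, aggregation is a plain mean of the current momentum and the momenta buffered at earlier steps. The class names `CriticalMomentaBuffer` and `AdamCM` and all default values are illustrative.

```python
import math
from typing import List

Vector = List[float]


class CriticalMomentaBuffer:
    """Hypothetical fixed-capacity buffer of past momenta, ranked by L2 norm."""

    def __init__(self, capacity: int, decay: float):
        self.capacity = capacity  # C: maximum number of stored momenta
        self.decay = decay        # lambda: per-step multiplicative priority decay
        self.items: List[list] = []  # entries are [priority, momentum]

    def offer(self, m: Vector) -> None:
        """Store m if there is room, or if it outranks the weakest entry."""
        priority = math.sqrt(sum(x * x for x in m))
        if len(self.items) < self.capacity:
            self.items.append([priority, list(m)])
            return
        weakest = min(range(len(self.items)), key=lambda k: self.items[k][0])
        if priority > self.items[weakest][0]:
            self.items[weakest] = [priority, list(m)]

    def decay_priorities(self) -> None:
        for item in self.items:
            item[0] *= self.decay


class AdamCM:
    """Sketch of Adam whose update uses momentum aggregated with a buffer."""

    def __init__(self, dim: int, lr: float = 0.01, beta1: float = 0.9,
                 beta2: float = 0.999, eps: float = 1e-8,
                 capacity: int = 5, decay: float = 0.7):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * dim  # first moment (momentum)
        self.v = [0.0] * dim  # second moment
        self.t = 0
        self.buffer = CriticalMomentaBuffer(capacity, decay)

    def step(self, theta: Vector, grad: Vector) -> Vector:
        self.t += 1
        b1, b2 = self.beta1, self.beta2
        self.m = [b1 * mi + (1 - b1) * g for mi, g in zip(self.m, grad)]
        self.v = [b2 * vi + (1 - b2) * g * g for vi, g in zip(self.v, grad)]
        # Aggregate the current momentum with momenta buffered at earlier
        # steps (a plain mean, assumed here).
        vecs = [self.m] + [mom for _, mom in self.buffer.items]
        m_agg = [sum(col) / len(vecs) for col in zip(*vecs)]
        # Then offer the current momentum to the buffer and age all priorities.
        self.buffer.offer(self.m)
        self.buffer.decay_priorities()
        # Adam bias correction; applying the current step's factor to the
        # aggregate is a simplification.
        m_hat = [x / (1 - b1 ** self.t) for x in m_agg]
        v_hat = [x / (1 - b2 ** self.t) for x in self.v]
        return [p - self.lr * mh / (math.sqrt(vh) + self.eps)
                for p, mh, vh in zip(theta, m_hat, v_hat)]


# Usage: minimise the ill-conditioned quadratic f(x, y) = x^2 + 10 y^2.
def quad_loss(p: Vector) -> float:
    return p[0] ** 2 + 10 * p[1] ** 2


def quad_grad(p: Vector) -> Vector:
    return [2 * p[0], 20 * p[1]]


opt = AdamCM(dim=2, lr=0.01)
theta = [3.0, 2.0]
initial_loss = quad_loss(theta)
for _ in range(2000):
    theta = opt.step(theta, quad_grad(theta))
final_loss = quad_loss(theta)
print(f"loss: {initial_loss:.3f} -> {final_loss:.6f}")
```

Buffering raw gradients instead of momenta in this sketch would give a critical-gradient (Adam+CG-style) variant. Storing momenta is the change that separates the two here.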

Abstract

Adaptive gradient-based optimizers, notably Adam, have left their mark in training large-scale deep learning models, offering fast convergence and robustness to hyperparameter settings. However, they often generalize poorly, a shortcoming attributed to their tendency to converge to sharp minima in the loss landscape. To address this, we propose a new memory-augmented version of Adam that encourages exploration towards flatter minima by incorporating a buffer of critical momentum terms during training. This buffer prompts the optimizer to overshoot beyond narrow minima, promoting exploration. Through comprehensive analysis in simple settings, we illustrate the efficacy of our approach in increasing exploration and the bias towards flatter minima. We empirically demonstrate that it can improve model performance for image classification on ImageNet and CIFAR10/100, language modelling on Penn Treebank, and online learning tasks on TinyImageNet and 5-dataset. Our code is available at https://github.com/chandar-lab/CMOptimizer.

Paper Structure (25 sections, 13 equations, 22 figures, 12 tables, 1 algorithm)

Figures (22)

  • Figure 1: (Left) Learning trajectories for different optimizers on the Goldstein-Price loss function, starting from different initial points. While the other optimizers get stuck in sub-optimal regions, Adam+CM explores lower-loss regions and reaches the global minimum. (Right) Pseudo-code for Adam with critical momenta (Adam+CM).
  • Figure 2: First 10 steps of Adam+CG and Adam+CM trajectories on the Ackley loss surface. Coloured diamonds mark the final points reached by the optimizers. Adam+CG exhibits gradient cancellation: the buffer mean and the new gradient cancel each other out, yielding a small update. Conversely, Adam+CM escapes sub-optimal minima and converges near the global minimum.
  • Figure 3: Convergence rates ($1-\rho^\ast$) of classical momentum and critical momenta on quadratic objectives. Solid curves indicate that both $\alpha$ and $\beta$ were optimized to obtain $\rho^\ast$, while dashed curves indicate that $\rho^\ast$ was obtained with $\beta=0.9$ fixed. Critical momenta converge for a wide range of condition numbers in both cases.
  • Figure 4: Training loss curves (left, averaged across $10$ seeds) and learning trajectories (right, one seed) for different optimizers on the Ackley loss surface. While the other optimizers get stuck in sub-optimal minima near the initialisation point (black square), both CM variants explore and find the lower loss surface near the global solution (black diamond).
  • Figure 5: Optimization trajectories of Adam (left), Adam+CG (middle), and Adam+CM (right) on a toy 1D function with one flat and one sharp minimum. Sharpness increases across columns and the initialisation point varies across rows. Green backgrounds indicate that the optimizer escapes the sharper minimum, while red backgrounds indicate that it does not. The vertical line marks the final point in each sub-figure. Adam mostly converges to the minimum closest to the initial point. Across initial points and degrees of sharpness, Adam+CM converges to the flatter minimum more often than Adam+CG.
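
The classical-momentum curves in Figure 3 can be approximated numerically. This sketch assumes that classical momentum is Polyak's heavy-ball iteration $\theta_{t+1} = \theta_t - \alpha \nabla f(\theta_t) + \beta(\theta_t - \theta_{t-1})$. It is applied to a quadratic whose Hessian eigenvalues lie in $[\mu, L]$, with $\kappa = L/\mu$. The rate $\rho$ is taken as the worst-case spectral radius of the per-eigenvalue iteration, so the values match the figure's axis only if the paper uses the same parametrisation. The critical-momenta curve depends on the paper's own analysis and is not reproduced. All function names are hypothetical.

```python
import cmath
import math


def spectral_radius(alpha: float, beta: float, h: float) -> float:
    """Contraction factor of heavy-ball momentum on f(x) = h * x^2 / 2.

    x_{t+1} = x_t - alpha*h*x_t + beta*(x_t - x_{t-1}) has characteristic
    polynomial z^2 - (1 + beta - alpha*h) z + beta = 0; the largest root
    modulus is the asymptotic per-step contraction factor.
    """
    a = 1 + beta - alpha * h
    disc = cmath.sqrt(a * a - 4 * beta)
    return max(abs((a + disc) / 2), abs((a - disc) / 2))


def worst_case_rate(alpha: float, beta: float, mu: float, L: float,
                    n: int = 500) -> float:
    """rho: worst spectral radius over curvatures h sampled in [mu, L]."""
    hs = [mu + (L - mu) * i / (n - 1) for i in range(n)]
    return max(spectral_radius(alpha, beta, h) for h in hs)


def tuned_rate(kappa: float) -> float:
    """rho* with alpha and beta at Polyak's optimal values (mu = 1)."""
    mu, L = 1.0, float(kappa)
    alpha = 4 / (math.sqrt(L) + math.sqrt(mu)) ** 2
    beta = ((math.sqrt(kappa) - 1) / (math.sqrt(kappa) + 1)) ** 2
    return worst_case_rate(alpha, beta, mu, L)


def fixed_beta_rate(kappa: float, beta: float = 0.9, grid: int = 200) -> float:
    """rho* with beta fixed and alpha picked by log-spaced grid search (mu = 1)."""
    alphas = [10 ** (-4 + 4 * i / (grid - 1)) for i in range(grid)]
    return min(worst_case_rate(a, beta, 1.0, float(kappa)) for a in alphas)


for kappa in (10, 100, 1000):
    print(f"kappa={kappa:5d}  1-rho* tuned={1 - tuned_rate(kappa):.4f}  "
          f"beta=0.9: {1 - fixed_beta_rate(kappa):.4f}")
```

With both parameters tuned, the sketch recovers the textbook heavy-ball rate $\rho^\ast = (\sqrt{\kappa}-1)/(\sqrt{\kappa}+1)$. With $\beta = 0.9$ fixed, the best achievable rate can never drop below $\sqrt{\beta} \approx 0.949$. The reason is that the two roots of the characteristic polynomial multiply to $\beta$.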