Promoting Exploration in Memory-Augmented Adam using Critical Momenta
Pranshu Malviya, Gonçalo Mordido, Aristide Baratin, Reza Babanezhad Harikandeh, Jerry Huang, Simon Lacoste-Julien, Razvan Pascanu, Sarath Chandar
TL;DR
This work addresses the generalization gap of adaptive optimizers by promoting exploration of the loss landscape. It introduces Adam+CM, a memory-augmented variant of Adam that stores a buffer of critical momenta and aggregates them with the current momentum, so that the optimizer overshoots sharp minima and is biased towards flatter ones. Theoretical insights are provided under simplified quadratic-loss assumptions. Extensive experiments on language modeling, image classification, and online learning show that Adam+CM often outperforms Adam, critical-gradient (CG) variants, and sharpness-aware minimization (SAM), both in the flatness of the solutions it finds and in generalization. The approach offers practical performance gains and can be combined with existing optimizers to enhance exploration at a modest memory cost.
Abstract
Adaptive gradient-based optimizers, notably Adam, have become a mainstay for training large-scale deep learning models, offering fast convergence and robustness to hyperparameter settings. However, they often generalize poorly, a shortcoming attributed to their tendency to converge to sharp minima of the loss landscape. To address this, we propose a new memory-augmented version of Adam that encourages exploration towards flatter minima by incorporating a buffer of critical momentum terms during training. This buffer prompts the optimizer to overshoot beyond narrow minima, promoting exploration. Through comprehensive analysis in simple settings, we show that our approach increases exploration and the bias towards flatter minima. We empirically demonstrate that it can improve model performance for image classification on ImageNet and CIFAR10/100, language modelling on Penn Treebank, and online learning tasks on TinyImageNet and 5-dataset. Our code is available at https://github.com/chandar-lab/CMOptimizer.
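The TL;DR and abstract describe the buffer of critical momenta only at a high level. The Python sketch below illustrates one plausible way such a buffer could be combined with Adam. It is a hypothetical illustration, not the authors' implementation (the linked repository holds that). The class name `CriticalMomentaAdam`, the buffer size `capacity`, the priority decay factor `decay`, the use of the current gradient norm as the priority, and the plain averaging of current and buffered momenta are all assumptions made for this example.

```python
import heapq
import math
from typing import List, Tuple


class CriticalMomentaAdam:
    """Hypothetical sketch of Adam with a buffer of "critical" momenta.

    Assumptions (ours, not the paper's exact algorithm): the buffer keeps at most
    `capacity` bias-corrected momenta ranked by the gradient norm seen when they
    were stored, stored priorities decay by `decay` each step, and the update
    direction is the mean of the current momentum and all buffered momenta.
    """

    def __init__(self, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, capacity=5, decay=0.7):
        self.lr, self.eps = lr, eps
        self.beta1, self.beta2 = betas
        self.capacity, self.decay = capacity, decay
        self.m: List[float] = []
        self.v: List[float] = []
        self.t = 0
        # Min-heap of (priority, insertion id, momentum); the root is evicted first.
        self.buffer: List[Tuple[float, int, List[float]]] = []
        self._next_id = 0

    def step(self, params: List[float], grads: List[float]) -> List[float]:
        if not self.m:
            self.m = [0.0] * len(params)
            self.v = [0.0] * len(params)
        self.t += 1
        b1, b2 = self.beta1, self.beta2

        # Standard Adam moment estimates with bias correction.
        self.m = [b1 * m + (1 - b1) * g for m, g in zip(self.m, grads)]
        self.v = [b2 * v + (1 - b2) * g * g for v, g in zip(self.v, grads)]
        m_hat = [m / (1 - b1 ** self.t) for m in self.m]
        v_hat = [v / (1 - b2 ** self.t) for v in self.v]

        # Aggregate the current momentum with the buffered (older) momenta.
        stored = [mom for _, _, mom in self.buffer]
        agg = [(mh + sum(s[k] for s in stored)) / (1 + len(stored))
               for k, mh in enumerate(m_hat)]

        # Decay stored priorities; a common positive factor keeps the heap valid.
        self.buffer = [(p * self.decay, i, mom) for p, i, mom in self.buffer]

        # Offer the current momentum, prioritized by the current gradient norm.
        priority = math.sqrt(sum(g * g for g in grads))
        entry = (priority, self._next_id, m_hat)
        self._next_id += 1
        if len(self.buffer) < self.capacity:
            heapq.heappush(self.buffer, entry)
        elif self.buffer and priority > self.buffer[0][0]:
            heapq.heapreplace(self.buffer, entry)  # evict the lowest-priority momentum

        # Adam-style step using the aggregated momentum instead of m_hat alone.
        return [p - self.lr * a / (math.sqrt(vh) + self.eps)
                for p, a, vh in zip(params, agg, v_hat)]


if __name__ == "__main__":
    plain = CriticalMomentaAdam(lr=0.1, capacity=0)  # capacity=0 is plain Adam
    memory = CriticalMomentaAdam(lr=0.1, capacity=5)
    xp, xm = [1.0], [1.0]
    for g in (1.0, -1.0):  # the gradient flips sign on the second step
        xp, xm = plain.step(xp, [g]), memory.step(xm, [g])
    print("plain Adam:", xp, "with buffer:", xm)
```

In the demo, plain Adam reverses direction once the gradient flips (x ends near 0.905), whereas the buffered momentum from the first step keeps pushing x in its original direction (x ends near 0.853). This is a toy version of the overshooting behaviour the abstract describes. A min-heap keyed on priority makes evicting the least critical momentum an O(log C) operation.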
