MoMo: Momentum Models for Adaptive Learning Rates
Fabian Schaipp, Ruben Ohana, Michael Eickenberg, Aaron Defazio, Robert M. Gower
TL;DR
MoMo is a model-based framework for computing adaptive learning rates in momentum methods: it averages past sampled losses and gradients into a surrogate model of the full loss $f(x)$ and truncates that model at a known lower bound. The approach unifies momentum with Polyak-type step sizes and yields two methods, MoMo (for SGD-M) and MoMo-Adam (for Adam, via preconditioning). The authors prove $\mathcal{O}(1/\sqrt{K})$ convergence for convex problems with interpolation and report improved robustness to learning-rate tuning on image classification, recommender-system, translation, and diffusion-model tasks; for losses whose lower bound is unknown, the bound is estimated online. In practice, this could reduce the amount of learning-rate tuning needed when training a new model on a new dataset.
Abstract
Training a modern machine learning architecture on a new task requires extensive learning-rate tuning, which comes at a high computational cost. Here we develop new Polyak-type adaptive learning rates that can be used on top of any momentum method, and require less tuning to perform well. We first develop MoMo, a Momentum Model based adaptive learning rate for SGD-M (stochastic gradient descent with momentum). MoMo uses momentum estimates of the losses and gradients sampled at each iteration to build a model of the loss function. Our model makes use of any known lower bound of the loss function by using truncation; for example, most losses are lower-bounded by zero. The model is then approximately minimized at each iteration to compute the next step. We show how MoMo can be used in combination with any momentum-based method, and showcase this by developing MoMo-Adam, which is Adam with our new model-based adaptive learning rate. We show that MoMo attains an $\mathcal{O}(1/\sqrt{K})$ convergence rate for convex problems with interpolation, needing knowledge of no problem-specific quantities other than the optimal value. Additionally, for losses with unknown lower bounds, we develop on-the-fly estimates of a lower bound, which are incorporated into our model. We show that MoMo and MoMo-Adam improve over SGD-M and Adam in terms of robustness to hyperparameter tuning for training image classifiers on MNIST, CIFAR, and ImageNet, for recommender systems on Criteo, for a transformer model on the translation task IWSLT14, and for a diffusion model.
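To make the idea of a truncated momentum model concrete, the following is a minimal Python sketch of one such update, not the authors' reference implementation. It assumes exponential averaging with a coefficient `beta`, a user-chosen learning-rate cap `lr`, and a known lower bound `lower_bound`; the function name `momo_like_step`, the `state` dictionary, and all parameter names are hypothetical. The averaged linearizations of the loss define a model of the form max(avg_loss + <d, y> - gamma, lower_bound), and minimizing this model plus a proximal term gives a Polyak-type step size capped at `lr`.

```python
def dot(u, v):
    """Inner product of two equal-length lists of floats."""
    return sum(a * b for a, b in zip(u, v))


def momo_like_step(x, loss, grad, state, lr=1.0, beta=0.9, lower_bound=0.0):
    """One step of a momentum-model update with a truncated Polyak-type step.

    Assumed scheme (illustrative only): keep exponential averages of the
    sampled loss, the sampled gradient, and the inner product <grad, x>.
    Returns the new iterate and the step size that was used.
    """
    if not state:
        # First call: initialize the averages with the first sample.
        state["f"] = loss
        state["d"] = list(grad)
        state["gamma"] = dot(grad, x)
    else:
        state["f"] = beta * state["f"] + (1 - beta) * loss
        state["d"] = [beta * di + (1 - beta) * gi
                      for di, gi in zip(state["d"], grad)]
        state["gamma"] = beta * state["gamma"] + (1 - beta) * dot(grad, x)

    d = state["d"]
    # Value of the averaged linear model at the current point x.
    h = state["f"] + dot(d, x) - state["gamma"]
    d_norm_sq = dot(d, d)
    if d_norm_sq == 0.0:
        return list(x), 0.0  # no direction to move in

    # Truncation at the lower bound gives the positive part; lr caps the step.
    step = min(lr, max(h - lower_bound, 0.0) / d_norm_sq)
    return [xi - step * di for xi, di in zip(x, d)], step


def quadratic(x):
    """Toy objective f(x) = 0.5 * ||x||^2 with minimum value 0."""
    return 0.5 * dot(x, x), list(x)


if __name__ == "__main__":
    x, state = [3.0, -4.0], {}
    for k in range(50):
        loss, grad = quadratic(x)
        x, step = momo_like_step(x, loss, grad, state, lr=10.0)
        if k < 3:
            print(k, loss, step)
```

On the toy quadratic, the first step uses the classical Polyak step size (12.5 / 25 = 0.5) rather than the much larger cap `lr=10.0`, which illustrates why this kind of rule can tolerate a poorly tuned learning rate. A preconditioned variant in the spirit of MoMo-Adam would replace the Euclidean norm and direction with preconditioned ones; that is not shown here.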
