MoMo: Momentum Models for Adaptive Learning Rates

Fabian Schaipp, Ruben Ohana, Michael Eickenberg, Aaron Defazio, Robert M. Gower

TL;DR

MoMo is a model-based framework that computes adaptive learning rates for momentum methods. It averages past sampled losses and gradients into a surrogate model of the full loss $f(x)$ and truncates that model at a known lower bound. The approach unifies momentum with Polyak-type step sizes and yields two methods: MoMo, built on SGD-M, and MoMo-Adam, built on Adam via preconditioning. The authors prove an $\mathcal{O}(1/\sqrt{K})$ convergence rate for convex problems with interpolation, and they estimate the lower bound online when it is unknown. Experiments on image classification, recommender systems, translation, and a diffusion model show increased robustness to hyperparameter tuning compared with SGD-M and Adam. The work potentially reduces the need for extensive learning-rate tuning in practice, enabling more reliable out-of-the-box performance for diverse models and datasets.

Abstract

Training a modern machine learning architecture on a new task requires extensive learning-rate tuning, which comes at a high computational cost. Here we develop new Polyak-type adaptive learning rates that can be used on top of any momentum method, and require less tuning to perform well. We first develop MoMo, a Momentum Model based adaptive learning rate for SGD-M (stochastic gradient descent with momentum). MoMo uses momentum estimates of the losses and gradients sampled at each iteration to build a model of the loss function. Our model makes use of any known lower bound of the loss function by using truncation, e.g. most losses are lower-bounded by zero. The model is then approximately minimized at each iteration to compute the next step. We show how MoMo can be used in combination with any momentum-based method, and showcase this by developing MoMo-Adam, which is Adam with our new model-based adaptive learning rate. We show that MoMo attains a $\mathcal{O}(1/\sqrt{K})$ convergence rate for convex problems with interpolation, needing knowledge of no problem-specific quantities other than the optimal value. Additionally, for losses with unknown lower bounds, we develop on-the-fly estimates of a lower bound, that are incorporated in our model. We show that MoMo and MoMo-Adam improve over SGD-M and Adam in terms of robustness to hyperparameter tuning for training image classifiers on MNIST, CIFAR, and Imagenet, for recommender systems on Criteo, for a transformer model on the translation task IWSLT14, and for a diffusion model.
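
To make the mechanism in the abstract concrete, here is a minimal sketch of one truncated-model momentum step. It is not the authors' reference implementation. It assumes exponential averaging with a hypothetical coefficient `beta`, a step-size cap `alpha`, and a known lower bound `f_star`; the names `momo_step`, `state`, `f_star`, and `dot` are illustrative. The step size follows from minimizing the truncated averaged model $\max\{\bar f + \langle d, y - x\rangle + \langle d, x\rangle - \gamma,\ f_*\}$ plus the proximal term $\frac{1}{2\alpha}\|y - x\|^2$, which gives $\tau = \min\{\alpha,\ (h - f_*)_+ / \|d\|^2\}$ with $h = \bar f + \langle d, x\rangle - \gamma$.

```python
# Sketch only: all names here are illustrative assumptions,
# not the authors' reference implementation.

def dot(u, v):
    """Inner product of two equal-length lists of floats."""
    return sum(a * b for a, b in zip(u, v))


def momo_step(x, loss, grad, state, alpha=1.0, beta=0.9, f_star=0.0):
    """One truncated-model momentum step on parameters x (list of floats).

    loss, grad: sampled loss value and gradient evaluated at x.
    state: dict with the running averages, or None on the first call.
    Returns (new parameters, new state, step size actually used).
    """
    gx = dot(grad, x)
    if state is None:
        # Initialise the averages with the first sample.
        state = {"f": loss, "d": list(grad), "gamma": gx}
    else:
        # Exponential averages of losses, gradients, and <gradient, iterate>.
        state = {
            "f": beta * state["f"] + (1 - beta) * loss,
            "d": [beta * di + (1 - beta) * gi
                  for di, gi in zip(state["d"], grad)],
            "gamma": beta * state["gamma"] + (1 - beta) * gx,
        }
    d = state["d"]
    # Value of the averaged linear model at the current iterate.
    h = state["f"] + dot(d, x) - state["gamma"]
    d_norm_sq = dot(d, d)
    if d_norm_sq == 0.0:
        return list(x), state, 0.0
    # Truncation: do not step past the point where the model reaches f_star.
    tau = min(alpha, max(h - f_star, 0.0) / d_norm_sq)
    x_new = [xi - tau * di for xi, di in zip(x, d)]
    return x_new, state, tau


# Usage on the toy loss f(x) = 0.5 * ||x||^2, whose minimum value is 0.
x, state = [4.0, -2.0], None
for _ in range(5):
    loss = 0.5 * dot(x, x)
    grad = list(x)
    x, state, tau = momo_step(x, loss, grad, state)
    print(f"loss={loss:.4f} tau={tau:.4f}")
```

On the first call the averages equal the single sample, so the step reduces to a Polyak-type step capped at `alpha`. Later calls blend in the history, and the `max(..., 0.0)` truncation shrinks the step whenever the averaged model is already close to the lower bound.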

Paper Structure

This paper contains 47 sections, 15 theorems, 103 equations, 13 figures, 1 table, 6 algorithms.

Key Result

lemma 0

[MoMo update] Let … Using model eq:posmodel, the closed-form solution to eq:modelbasedVI is …

Figures (13)

  • Figure 1: Illustration of the MoMo model (blue curves) for two different loss functions with $\alpha_k =5$. Due to truncation, the new iterate of MoMo (blue point) is closer to the minimum than SGD-M (orange point). The right plot shows how MoMo takes a small step when gradients are steep, whereas SGD-M takes a large step and ends up far from the solution.
  • Figure 2: Training loss and validation accuracy after a fixed number of epochs, for varying (constant) learning rate $\alpha_0$. Shaded area depicts two standard deviations.
  • Figure 3: ViT for Imagenet-1k. Left: Final validation set accuracy (top-1) for different learning-rate values $\alpha_\text{base}$. Right: Training curves for the three best values of $\alpha_\text{base}$ for both methods.
  • Figure 4: Validation accuracy over a range of learning rates $\alpha_0$. (a) Imagenet32 without weight decay ($\lambda=0$). (b) Left: IWSLT14 translation task with dropout 0.1 (plain) or 0.3 (dashed). Right: Learning rate schedule (black) and adaptive step sizes (grey dots) of MoMo-Adam$^*$ for $\alpha_\text{base}=10^{-3}$.
  • Figure 6.1: Validation score over training. For each method, we plot the three choices of $\alpha_0$ that lead to the best validation score (compare to fig:stability_val_score and fig:stability_appendix).
  • ...and 8 more figures

Theorems & Definitions (28)

  • lemma 0
  • remark 1
  • remark 2: Complexity
  • lemma 2
  • lemma 2
  • lemma 2
  • theorem 2
  • lemma 3
  • proof
  • remark 4
  • ...and 18 more