On the Generalization of Stochastic Gradient Descent with Momentum

Ali Ramezani-Kebrya, Kimon Antonakopoulos, Volkan Cevher, Ashish Khisti, Ben Liang

TL;DR

The paper studies how momentum affects the generalization of stochastic gradient methods. It shows that standard SGD with heavy-ball momentum (SGDM) can have an unbounded stability gap on some convex losses. This motivates SGD with early momentum (SGDEM), which applies momentum only during an initial phase of training and comes with generalization guarantees for smooth Lipschitz losses, together with explicit convergence and true-risk bounds. The authors derive both expected and high-probability generalization bounds, show favorable behavior for SGDEM in multi-epoch training, and identify a strongly convex case where SGDM generalizes for a suitable range of momentum. Empirically, SGDEM improves generalization and training performance on CIFAR-10 and in distributed ImageNet training. Overall, the work clarifies the trade-off between optimization efficiency and generalization risk when using momentum in SGD.

Abstract

While momentum-based accelerated variants of stochastic gradient descent (SGD) are widely used when training machine learning models, there is little theoretical understanding of the generalization error of such methods. In this work, we first show that there exists a convex loss function for which the stability gap for multiple epochs of SGD with standard heavy-ball momentum (SGDM) becomes unbounded. Then, for smooth Lipschitz loss functions, we analyze a modified momentum-based update rule, i.e., SGD with early momentum (SGDEM) under a broad range of step-sizes, and show that it can train machine learning models for multiple epochs with a guarantee for generalization. Finally, for the special case of strongly convex loss functions, we find a range of momentum such that multiple epochs of standard SGDM, as a special form of SGDEM, also generalize. Extending our results on generalization, we also develop an upper bound on the expected true risk, in terms of the number of training steps, sample size, and momentum. Our experimental evaluations verify the consistency between the numerical results and our theoretical bounds. SGDEM improves the generalization error of SGDM when training ResNet-18 on ImageNet in practical distributed settings.
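
The early-momentum idea in the abstract can be illustrated with a minimal sketch. It assumes the usual heavy-ball form $w_{t+1} = w_t + \mu\,(w_t - w_{t-1}) - \alpha \nabla f(w_t; z)$, a constant step-size, and no projection step; these choices, the function name `sgdem`, and the toy least-squares problem are assumptions made for illustration and are not taken from the paper. The momentum term is active for the first `t_d` steps and switched off afterwards, so `t_d = num_steps` gives plain SGDM and `t_d = 0` gives plain SGD.

```python
import numpy as np


def sgdem(grad_fn, w0, data, step_size, momentum, t_d, num_steps, seed=0):
    """Hypothetical sketch of SGD with early momentum (SGDEM).

    The heavy-ball term is applied for the first t_d steps only;
    afterwards the update reduces to plain SGD.
    """
    rng = np.random.default_rng(seed)
    w_prev = np.array(w0, dtype=float)
    w = w_prev.copy()
    for t in range(num_steps):
        z = data[rng.integers(len(data))]       # sample one training example
        mu = momentum if t < t_d else 0.0       # momentum only in the early phase
        w_next = w + mu * (w - w_prev) - step_size * grad_fn(w, z)
        w_prev, w = w, w_next
    return w


def grad_squared_loss(w, z):
    """Gradient of 0.5 * (w @ x - y)**2 with respect to w."""
    x, y = z
    return (w @ x - y) * x


# Toy noiseless least-squares problem (illustrative data only).
data_rng = np.random.default_rng(1)
X = data_rng.normal(size=(200, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true
data = list(zip(X, y))

w_hat = sgdem(grad_squared_loss, np.zeros(3), data,
              step_size=0.01, momentum=0.9, t_d=200, num_steps=2000)
print(np.linalg.norm(w_hat - w_true))
```

In the ImageNet experiments described in the figure captions, the switch happens at an epoch boundary rather than at a step index; the sketch uses steps only to stay short.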

Paper Structure

This paper contains 33 sections, 29 theorems, 152 equations, 14 figures, 1 table.

Key Result

Theorem 2

If $A$ is an $\epsilon_s$-uniformly stable algorithm, then the generalization error of $A$ is upper bounded by $\epsilon_s$.
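
For context, the standard formulation of uniform stability from Hardt et al. can be written as follows. The notation is an assumption, since this summary does not reproduce Definition 1: $S$ and $S'$ are training sets of equal size that differ in at most one example, $f(w; z)$ is the loss of model $w$ on example $z$, $R$ is the true risk, and $R_S$ is the empirical risk on $S$.

$$
\sup_{z}\; \mathbb{E}_{A}\big[f(A(S); z) - f(A(S'); z)\big] \le \epsilon_s \quad \text{for all such } S, S',
$$

$$
\Big|\mathbb{E}_{S,A}\big[R(A(S)) - R_S(A(S))\big]\Big| \le \epsilon_s .
$$

The first line is the stability condition and the second is the resulting bound on the expected generalization error.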

Figures (14)

  • Figure 1: Validation loss and generalization error of SGDEM when training ResNet-18 on ImageNet in a distributed setting with 4 GPUs under tuned step-size and global minibatch size of 128. For each $t_d$, the momentum is set to $\mu_d=0.9$ in the first $t_d$ epochs and then zero for the next $90-t_d$ epochs. SGDM is a special form of SGDEM with $t_d=90$. Details are provided in the numerical evaluation section and the experimental appendix.
  • Figure 2: Test error (left) and test accuracy (middle) of ResNet-20 on CIFAR10. Test error of a feedforward fully connected neural network for notMNIST dataset (right).
  • Figure 3: Validation accuracy and generalization gap of SGDEM when training ResNet-18 on ImageNet in a distributed setting with 4 GPUs under tuned step-size and global minibatch size of 128. For each $t_d$, the momentum is set to $\mu_d=0.9$ in the first $t_d$ epochs and then zero for the next $90-t_d$ epochs. SGDM is a special form of SGDEM with $t_d=90$.
  • Figure 4: Generalization error (left) and training error (middle) of logistic regression (cross entropy loss) for notMNIST dataset with $T=1000$ iterations. Test accuracy of logistic regression for notMNIST dataset with $n=500$ (right).
  • Figure 5: Test accuracy of a feedforward fully connected neural network for notMNIST dataset.
  • ...and 9 more figures

Theorems & Definitions (40)

  • Definition 1
  • Theorem 2: Hardt
  • Example 1
  • Theorem 3
  • Theorem 4
  • Lemma 5
  • Corollary 6
  • Remark 7
  • Theorem 8
  • Remark 9
  • ...and 30 more