Gradient Descent with Polyak's Momentum Finds Flatter Minima via Large Catapults

Prin Phunyaphibarn, Junghyun Lee, Bohan Wang, Huishuai Zhang, Chulhee Yun

TL;DR

We show empirically that, for both linear diagonal networks and nonlinear neural networks, gradient descent with Polyak's momentum and a large learning rate exhibits large catapults, which drive the iterates towards much flatter minima than those found by plain gradient descent.

Abstract

Although gradient descent with Polyak's momentum is widely used in modern machine learning and deep learning, a concrete understanding of its effects on the training trajectory remains elusive. In this work, we empirically show that for linear diagonal networks and nonlinear neural networks, momentum gradient descent with a large learning rate displays large catapults, driving the iterates towards much flatter minima than those found by gradient descent. We hypothesize that the large catapult is caused by momentum "prolonging" the self-stabilization effect (Damian et al., 2023). We support this hypothesis with theory and experiments on a simple toy example, and with further empirical evidence for linear diagonal networks.
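
For readers unfamiliar with the optimizer, the following is a minimal sketch of the standard heavy-ball form of Polyak's momentum, $x_{t+1} = x_t - \eta \nabla L(x_t) + \beta (x_t - x_{t-1})$. The paper's exact parameterization is not shown in this summary, so the function name `heavy_ball`, the zero initial velocity, and the 1-D quadratic test problem are illustrative assumptions rather than the authors' setup.

```python
from typing import Callable, List


def heavy_ball(grad: Callable[[float], float], x0: float, lr: float,
               beta: float, steps: int) -> List[float]:
    """Iterate x_{t+1} = x_t - lr * grad(x_t) + beta * (x_t - x_{t-1}).

    Setting beta = 0 recovers plain gradient descent.
    """
    xs = [x0]
    prev = x0  # assumption: x_{-1} = x_0, i.e. zero initial velocity
    for _ in range(steps):
        x = xs[-1]
        xs.append(x - lr * grad(x) + beta * (x - prev))
        prev = x
    return xs


def quad_grad(x: float) -> float:
    """Gradient of the illustrative quadratic f(x) = 0.5 * x**2 (curvature 1)."""
    return x


# Same learning rate for both runs; only the momentum parameter differs.
gd_traj = heavy_ball(quad_grad, x0=1.0, lr=2.1, beta=0.0, steps=50)
phb_traj = heavy_ball(quad_grad, x0=1.0, lr=2.1, beta=0.9, steps=200)
```

On a quadratic with curvature $\lambda$, plain gradient descent is stable only when $\eta\lambda < 2$, while the heavy-ball recursion is stable when $\eta\lambda < 2(1+\beta)$. With $\eta = 2.1$ and $\lambda = 1$, the $\beta = 0$ run therefore grows without bound and the $\beta = 0.9$ run contracts. This linear picture only illustrates the update rule; it does not capture the catapult behavior studied in the paper, which arises on non-quadratic losses.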

Paper Structure

This paper contains 43 sections, 6 theorems, 36 equations, 17 figures.

Key Result

Lemma 4.1

$u_\infty := \lim_{t \rightarrow \infty} u_t \leq u_{t+1} \leq u_t$ for all $t \geq 0$. Furthermore, if $\tau_0 := \inf\left\{ t \geq 0 : u_t < 0 \right\} < \infty$, then we have that $u_\infty = \frac{u_{\tau_0} - \beta u_{\tau_0 - 1}}{1 - \beta}$.
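
As a reading aid, the closed form for $u_\infty$ can be rewritten as a geometric series. The interpretation below is an assumption, not something stated in this summary: it supposes that from step $\tau_0$ onward only the momentum term remains active, so that $u_{t+1} - u_t = \beta\,(u_t - u_{t-1})$ for all $t \geq \tau_0$, with $\beta \in [0, 1)$.

$$
u_\infty
= u_{\tau_0} + \sum_{k \geq 1} \beta^{k} \left( u_{\tau_0} - u_{\tau_0 - 1} \right)
= u_{\tau_0} + \frac{\beta}{1 - \beta} \left( u_{\tau_0} - u_{\tau_0 - 1} \right)
= \frac{u_{\tau_0} - \beta u_{\tau_0 - 1}}{1 - \beta}.
$$

Under that assumption, the limit sits at or below $u_{\tau_0}$, because $u_{\tau_0} \leq u_{\tau_0 - 1}$ by the monotonicity in the first part of the lemma. The overshoot past $u_{\tau_0}$ grows as $\beta \to 1$.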

Figures (17)

  • Figure 1: Experiments following the same setting as Nacson et al. (2022). In (a) and (b), "$\ell_1$ baseline" and "$\ell_2$ baseline" denote the minimum-$\ell_1$-norm and minimum-$\ell_2$-norm solutions to the regression problem, respectively. We use $\beta=0.9$ for PHB (gradient descent with Polyak's heavy-ball momentum).
  • Figure 2: Neural networks trained on (a-c) 1k and (d-f) 5k subsets of CIFAR10 (Krizhevsky, 2009). For (a,b,d,e), we use the MSE loss, and for (c,f) we use the CE loss. All FCNs are 3-layer and use ReLU activation. The shaded region is the linear warmup period.
  • Figure 3: (Left) Trajectories of GD, PHB, GD $\to$ PHB, PHB $\to$ GD with $\beta=0.9$, $\eta = (2+\epsilon)/u_0^2$ where $\epsilon=0.01$, $(u_0, v_0)=(10, 10^{-6})$, and no warmup. (Right) The self-stabilization stages for GD are highlighted and labeled in the sharpness plot. The MSS is shown as the red dotted line.
  • Figure 4: Numerical verification of Theorem \ref{thm:toy}.
  • Figure 5: Empirical validation of Hypothesis \ref{hypothesis} and theory using LDNs.
  • ...and 12 more figures

Theorems & Definitions (12)

  • Remark 2.1: Learning Rate Warmup
  • Remark 2.2: Role of Warmup in Large Catapults
  • Lemma 4.1
  • Theorem 4.1
  • Theorem 4.2: Informal
  • proof : Proof sketch.
  • Theorem B.1
  • proof : Proof of Theorem \ref{thm:toy-GD-full}
  • Lemma B.1
  • proof : Proof of Lemma \ref{lem:Pt-exp-bound}
  • ...and 2 more