Large Spikes in Stochastic Gradient Descent: A Large-Deviations View

Benjamin Gess, Daniel Heydecker

TL;DR

An explicit criterion separating two behaviours is identified: when an explicit function $G$ is positive, SGD produces large NTK-flattening spikes with high probability; when $G<0$, their probability decays like $(n/\eta)^{-\vartheta/2}$ for an explicitly characterised $\vartheta\in (0,\infty)$.

Abstract

We analyse SGD training of a shallow, fully connected network in the NTK scaling and provide a quantitative theory of the catapult phase. We identify an explicit criterion separating two behaviours: When an explicit function $G$, depending only on the kernel, learning rate $\eta$ and data, is positive, SGD produces large NTK-flattening spikes with high probability; when $G<0$, their probability decays like $(n/\eta)^{-\vartheta/2}$, for an explicitly characterised $\vartheta\in (0,\infty)$. This yields a concrete parameter-dependent explanation for why such spikes may still be observed at practical widths.
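The catapult mechanism can be seen in a standard deterministic toy model: full-batch gradient descent on a two-layer linear network $f(x)=v^\top u\,x/\sqrt{n}$ with a single training point $x=1$, $y=0$, whose NTK at that point is $\lambda=(\|u\|^2+\|v\|^2)/n$. This is a minimal sketch under stated assumptions, not the paper's minibatch SGD setting, and it does not compute $G$ or $\vartheta$. The helper `catapult_demo`, the width, seed, number of steps and the choice $\eta\lambda_0=3$ are all illustrative assumptions.

```python
import numpy as np


def catapult_demo(n: int = 1000, steps: int = 500, seed: int = 0):
    """Hypothetical helper (not from the paper): full-batch GD on
    f(x) = v.(u x)/sqrt(n) with one sample x = 1, target y = 0.

    Returns per-step losses, per-step NTK values and the learning rate used.
    """
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(n)  # first-layer weights
    v = rng.standard_normal(n)  # second-layer weights
    ntk0 = (u @ u + v @ v) / n  # NTK at x = 1: |df/du|^2 + |df/dv|^2
    eta = 3.0 / ntk0            # assumption: eta * ntk0 = 3, inside (2, 4)
    losses, ntks = [], []
    for _ in range(steps):
        r = u @ v / np.sqrt(n)  # residual f(1) - y with y = 0
        losses.append(0.5 * r**2)
        ntks.append((u @ u + v @ v) / n)
        # Gradients of L = r^2 / 2: dL/du = r v / sqrt(n), dL/dv = r u / sqrt(n).
        # The tuple assignment evaluates both right-hand sides with the old u, v.
        u, v = u - eta * r * v / np.sqrt(n), v - eta * r * u / np.sqrt(n)
    return np.array(losses), np.array(ntks), eta


if __name__ == "__main__":
    losses, ntks, eta = catapult_demo()
    print("loss: initial", losses[0], "peak", losses.max(), "final", losses[-1])
    print("eta * NTK: initial", eta * ntks[0], "final", eta * ntks[-1])
```

For this model one can check that $\lambda_{t+1}=\lambda_t-\frac{\eta f_t^2}{n}\,(4-\eta\lambda_t)$, so the NTK shrinks at every step while $\eta\lambda_t<4$, but only noticeably once $f_t^2$ has grown to order $n$. The loss spike is what drives $\eta\lambda$ below $2$, after which training converges.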

Paper Structure (48 sections, 10 theorems, 184 equations, 3 figures)

Key Result

Theorem 1

Let $\varphi$ be the linear activation $\varphi(w)=w$, and consider the range […]. Define, for any $\lambda$ in this range, […], which we allow to take the value $-\infty$ if $\eta\lambda s_i^2=1$ for some $i$, and […]. Then […]

Figures (3)

  • Figure 1: Extension of the phase diagram of zhu2022quadratic, which corresponds to cases (i-ii). The informal theorem unveils a rich internal structure of the catapult region for the nonlinear, nondeterministic SGD dynamics, illustrated with green ('inflationary', case a of the theorem) and pale blue ('deflationary', case b) stripes in (iii). In general, the critical and maximal curvatures $\lambda^{\rm MB}_{\rm crit, max}$ for minibatching are strictly smaller than their full-batch counterparts $\lambda^{\rm FB}_{\rm crit, max}$.
  • Figure 2: Plots of $G(\lambda)$ for the examples given in the text.
  • Figure 3: $\max(1,\vartheta(\lambda))$ (blue) and $n^{-\vartheta(\lambda)/2}$ with $n=10^{12}$ (red) for one of the example datasets in the text.
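To see why a polynomial rate can still leave room for spikes at practical widths, the scaling $(n/\eta)^{-\vartheta/2}$ can simply be evaluated. This sketch ignores any prefactors in the theorem; the values of $\vartheta$ and $\eta$ are hypothetical inputs rather than values computed in the paper, and `spike_rate_scaling` is an illustrative name.

```python
def spike_rate_scaling(n: float, eta: float, theta: float) -> float:
    """Return (n/eta)**(-theta/2), the decay rate of the spike probability when G < 0.

    Hypothetical helper: prefactors from the theorem are ignored.
    """
    if n <= 0 or eta <= 0 or theta <= 0:
        raise ValueError("n, eta and theta must be positive")
    return (n / eta) ** (-theta / 2)


if __name__ == "__main__":
    n, eta = 1e12, 1.0  # width as in Figure 3; eta = 1 is an illustrative choice
    for theta in (0.1, 0.5, 1.0, 4.0):  # hypothetical exponents
        print(f"theta={theta}: {spike_rate_scaling(n, eta, theta):.3e}")
```

Small exponents dominate the picture: with $n/\eta=10^{12}$, $\vartheta=1$ already gives $10^{-6}$, whereas $\vartheta=0.1$ gives about $0.25$.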

Theorems & Definitions (23)

  • Theorem 1
  • Theorem 2
  • Remark 1.2
  • Theorem 3
  • Remark 3.1
  • Theorem 4
  • Lemma 4.1: Hitting Probabilities, Deflationary Case
  • Proof
  • Proposition 4.2: Slow Escape is Exponentially Improbable
  • Proof
  • ...and 13 more