Large Spikes in Stochastic Gradient Descent: A Large-Deviations View
Benjamin Gess, Daniel Heydecker
TL;DR
An explicit criterion separating two behaviours is identified: When an explicit function $G$ is positive, SGD produces large NTK-flattening spikes with high probability; when $G<0$, their probability decays like $(n/\eta)^{-\vartheta/2}$, for an explicitly characterised $\vartheta\in (0,\infty)$.
Abstract
We analyse SGD training of a shallow, fully connected network in the NTK scaling and provide a quantitative theory of the catapult phase. We identify an explicit criterion separating two behaviours: When an explicit function $G$, depending only on the kernel, learning rate $\eta$ and data, is positive, SGD produces large NTK-flattening spikes with high probability; when $G<0$, their probability decays like $(n/\eta)^{-\vartheta/2}$, for an explicitly characterised $\vartheta\in (0,\infty)$. This yields a concrete parameter-dependent explanation for why such spikes may still be observed at practical widths.
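
To give a feel for the polynomial rate in the $G<0$ regime, the following minimal sketch evaluates the scale $(n/\eta)^{-\vartheta/2}$ for a few inputs. It assumes that $n$ denotes the network width, which the abstract suggests but does not state. The function name `spike_decay_scale` and all numerical values are hypothetical, chosen only for illustration, and the output is a rate up to unspecified constants, not an actual probability.

```python
def spike_decay_scale(n: float, eta: float, theta: float) -> float:
    """Evaluate (n / eta) ** (-theta / 2).

    This is only the decay scale quoted for the G < 0 regime; constants
    and lower-order factors are not given in the abstract and are omitted.
    """
    if n <= 0 or eta <= 0 or theta <= 0:
        raise ValueError("n, eta and theta must all be positive")
    return (n / eta) ** (-theta / 2.0)


if __name__ == "__main__":
    eta = 0.1    # hypothetical learning rate
    theta = 0.5  # hypothetical exponent; the paper characterises the true value
    for n in (1e2, 1e3, 1e4, 1e5):  # hypothetical values of n
        scale = spike_decay_scale(n, eta, theta)
        print(f"n = {n:>8.0f}  ->  (n/eta)^(-theta/2) = {scale:.4f}")
```

Because the decay is polynomial rather than exponential in $n/\eta$, a small exponent $\vartheta$ makes the scale shrink slowly as $n$ grows, which is consistent with the abstract's remark that spikes may still be observed at practical widths.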
