A Theoretical Framework for Grokking: Interpolation followed by Riemannian Norm Minimisation
Etienne Boursier, Scott Pesme, Radu-Alexandru Dragomir
TL;DR
This paper studies gradient flow with weight decay on a general loss $F:\mathbb{R}^d\to\mathbb{R}$ and shows that, as $\lambda\to 0$, the trajectory splits into two phases. In a fast phase, it follows the unregularised gradient flow to a manifold $\mathcal{M}$ of critical points of $F$. In a slow phase, on timescales of order $1/\lambda$, it drifts along $\mathcal{M}$ following a Riemannian gradient flow that minimises the $\ell_2$-norm. This optimisation-based two-timescale picture offers a principled explanation for grokking: a rapid early drop of the training loss to zero can be followed by a delayed improvement in generalisation, driven by norm reduction along the interpolation manifold. The authors formalise the fast and slow dynamics, establish convergence towards a minimiser of the $\ell_2$-norm constrained to $\mathcal{M}$, and illustrate the mechanism in synthetic experiments on linear regression, matrix completion, and (non-smooth) ReLU architectures. The work highlights the role of norm-based regularisation and interpolation manifolds in delayed generalisation, and offers a framework for explaining grokking beyond NTK regimes, on general loss landscapes.
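The two phases can be written schematically as ODEs. This is a sketch under stated assumptions, not the paper's theorem: it assumes the weight-decay convention $F(\theta)+\frac{\lambda}{2}\|\theta\|_2^2$, writes $\theta_\infty\in\mathcal{M}$ for the limit of the unregularised flow and $P_{T_\theta\mathcal{M}}$ for the orthogonal projection onto the tangent space of $\mathcal{M}$ at $\theta$ (both notations introduced here), and leaves the precise regularity assumptions and modes of convergence to the paper.

$$\dot\theta_\lambda(t) = -\nabla F(\theta_\lambda(t)) - \lambda\,\theta_\lambda(t), \qquad \theta_\lambda(0) = \theta_0.$$

Fast phase ($t$ of order $1$): as $\lambda\to 0$, $\theta_\lambda(t)\to\phi(t)$, where

$$\dot\phi(t) = -\nabla F(\phi(t)), \qquad \phi(0) = \theta_0, \qquad \phi(t)\xrightarrow[t\to\infty]{}\theta_\infty\in\mathcal{M}.$$

Slow phase ($t = \tau/\lambda$): $\theta_\lambda(\tau/\lambda)\to\psi(\tau)$, where

$$\dot\psi(\tau) = -P_{T_{\psi(\tau)}\mathcal{M}}\,\psi(\tau) = -\operatorname{grad}_{\mathcal{M}}\Big(\tfrac{1}{2}\|\cdot\|_2^2\Big)(\psi(\tau)), \qquad \psi(0) = \theta_\infty.$$

The second equality uses the standard fact that, for the metric induced on an embedded submanifold, the Riemannian gradient of $\frac{1}{2}\|\theta\|_2^2$ is the tangential projection of its Euclidean gradient $\theta$. The rescaling $t=\tau/\lambda$ is why the drift only becomes visible at times of order $1/\lambda$.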
Abstract
We study the dynamics of gradient flow with small weight decay on general training losses $F: \mathbb{R}^d \to \mathbb{R}$. Under mild regularity assumptions and assuming convergence of the unregularised gradient flow, we show that the trajectory with weight decay $\lambda$ exhibits a two-phase behaviour as $\lambda\to 0$. During the initial fast phase, the trajectory follows the unregularised gradient flow and converges to a manifold of critical points of $F$. Then, at time of order $1/\lambda$, the trajectory enters a slow drift phase and follows a Riemannian gradient flow minimising the $\ell_2$-norm of the parameters. This purely optimisation-based phenomenon offers a natural explanation for the grokking effect observed in deep learning, where the training loss rapidly reaches zero while the test loss plateaus for an extended period before suddenly improving. We argue that this generalisation jump can be attributed to the slow norm reduction induced by weight decay, as explained by our analysis. We validate this mechanism empirically on several synthetic regression tasks.
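A toy illustration of the two phases on over-parameterised linear regression may help. This is a minimal sketch, not the authors' experimental code: the loss $F(\theta)=\frac{1}{2}\|X\theta-y\|^2$ with $n<d$, the random data, the choice of $X$ with orthonormal rows, the step size and the function name `run` are all assumptions made for illustration. Here $\mathcal{M}=\{\theta : X\theta=y\}$ is affine with tangent space $\ker X$, so the slow Riemannian flow simply shrinks the null-space component of $\theta$ at rate $\lambda$. The sketch tracks training loss and $\|\theta\|$ only; it does not model test loss.

```python
import numpy as np


def run(lam=1e-2, lr=0.1, t_max=1000.0, n=5, d=20, seed=0):
    """Explicit Euler discretisation of  d theta/dt = -grad F(theta) - lam * theta.

    Toy setting (assumption): F(theta) = 0.5 * ||X theta - y||^2 with n < d, so the
    zero-loss set M = {theta : X theta = y} is an affine manifold of global minima.
    Returns (times, train_losses, parameter_norms, theta_star).
    """
    rng = np.random.default_rng(seed)
    # Orthonormal rows (X X^T = I_n) keep the fast phase well conditioned.
    q, _ = np.linalg.qr(rng.standard_normal((d, n)))
    X = q.T
    y = rng.standard_normal(n)
    # Minimum-l2-norm interpolator: equals pinv(X) @ y because X X^T = I_n.
    theta_star = X.T @ y
    # Generic initialisation, with a sizeable component in the null space of X.
    theta = rng.standard_normal(d)

    times, losses, norms = [], [], []
    steps = int(round(t_max / lr))
    for k in range(steps + 1):
        residual = X @ theta - y
        times.append(k * lr)
        losses.append(0.5 * float(residual @ residual))
        norms.append(float(np.linalg.norm(theta)))
        if k < steps:
            # Gradient of F plus the weight-decay term lam * theta.
            theta = theta - lr * (X.T @ residual + lam * theta)
    return np.array(times), np.array(losses), np.array(norms), theta_star


if __name__ == "__main__":
    lam = 1e-2
    times, losses, norms, theta_star = run(lam=lam)
    for t in [0.0, 10.0, 1.0 / lam, 5.0 / lam, 10.0 / lam]:
        k = int(np.argmin(np.abs(times - t)))
        print(f"t={times[k]:7.1f}  train loss={losses[k]:.2e}  ||theta||={norms[k]:.3f}")
    print(f"min-norm interpolator: ||theta*||={np.linalg.norm(theta_star):.3f}")
```

In this toy run, the training loss should be near zero by $t\approx 10$ while $\|\theta\|$ is still well above $\|\theta^\star\|$. The norm should only settle near $\|\theta^\star\|$ after several multiples of $1/\lambda=100$. Because $\lambda>0$, the end point is the ridge solution $\theta^\star/(1+\lambda)$ rather than an exact interpolator, so the residual training loss is of order $\lambda^2$.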
