A Theoretical Framework for Grokking: Interpolation followed by Riemannian Norm Minimisation

Etienne Boursier, Scott Pesme, Radu-Alexandru Dragomir

TL;DR

This paper studies gradient flow with weight decay on a general loss $F:\mathbb{R}^d\to\mathbb{R}$ and shows that, as $\lambda\to 0$, the trajectory decomposes into two coupled phases. In the fast phase, the trajectory follows the unregularised gradient flow to a manifold $\mathcal{M}$ of stationary points. In the slow phase, on the timescale $t \sim 1/\lambda$, it drifts along $\mathcal{M}$ to minimise the $\ell_2$-norm via a Riemannian gradient flow. This optimisation-based two-timescale picture provides a principled explanation for grokking: an early, rapid reduction of the training loss to zero can be followed by a delayed improvement in generalisation driven by norm reduction along the interpolation manifold. The authors formalise the fast and slow dynamics, establish convergence to a constrained minimum on $\mathcal{M}$, and illustrate the mechanism in synthetic experiments on linear regression, matrix completion, and (non-smooth) ReLU architectures. The work highlights the role of norm-based regularisation and interpolation manifolds in delayed generalisation. It offers a framework for understanding grokking beyond NTK regimes and for extending the analysis to broader loss landscapes.
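For the linear-regression example mentioned above, the two timescales can be written out in closed form. The derivation below is our own sketch under stated assumptions, not the paper's proof. Take $F(w) = \tfrac12\|Xw - y\|^2$ with $X \in \mathbb{R}^{n\times d}$ of full row rank $n < d$. Let $X^{+} = X^\top(XX^\top)^{-1}$ and let $P = X^{+}X$ be the orthogonal projector onto the row space of $X$. The interpolation manifold is then the affine set $\mathcal{M} = \{w : Xw = y\}$, whose tangent space is $\ker X$. The first display solves the regularised flow componentwise; the second takes the limit $\lambda \to 0$ on the slow timescale $t = \tau/\lambda$.

$$
\begin{aligned}
\dot w^\lambda &= -X^\top(Xw^\lambda - y) - \lambda w^\lambda, \qquad w^\lambda(0) = w_0,\\
(I-P)\,w^\lambda(t) &= e^{-\lambda t}\,(I-P)\,w_0 \qquad \text{(since } (I-P)X^\top = 0\text{)},\\
P\,w^\lambda(t) &\xrightarrow[t\to\infty]{} X^\top(XX^\top + \lambda I)^{-1}y \xrightarrow[\lambda\to 0]{} X^{+}y \qquad \text{(at rate at least } \sigma_{\min}(X)^2 + \lambda\text{)}.
\end{aligned}
$$

$$
\lim_{\lambda\to 0} w^\lambda(\tau/\lambda) = X^{+}y + e^{-\tau}(I-P)\,w_0 =: z(\tau) \quad (\tau > 0), \qquad \frac{dz}{d\tau} = -(I-P)\,z = -\operatorname{grad}_{\mathcal{M}} \tfrac12\|z\|^2 .
$$

For fixed $t$ and $\lambda \to 0$, the trajectory reduces to the unregularised flow. Its limit $X^{+}y + (I-P)w_0$ is the orthogonal projection of $w_0$ onto $\mathcal{M}$; this is the fast phase. On the slow timescale, $z$ follows the Riemannian gradient flow of the squared norm on $\mathcal{M}$ and tends to the minimum-norm interpolator $X^{+}y$.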

Abstract

We study the dynamics of gradient flow with small weight decay on general training losses $F: \mathbb{R}^d \to \mathbb{R}$. Under mild regularity assumptions and assuming convergence of the unregularised gradient flow, we show that the trajectory with weight decay $\lambda$ exhibits a two-phase behaviour as $\lambda\to 0$. During the initial fast phase, the trajectory follows the unregularised gradient flow and converges to a manifold of critical points of $F$. Then, at time of order $1/\lambda$, the trajectory enters a slow drift phase and follows a Riemannian gradient flow minimising the $\ell_2$-norm of the parameters. This purely optimisation-based phenomenon offers a natural explanation for the \textit{grokking} effect observed in deep learning, where the training loss rapidly reaches zero while the test loss plateaus for an extended period before suddenly improving. We argue that this generalisation jump can be attributed to the slow norm reduction induced by weight decay, as explained by our analysis. We validate this mechanism empirically on several synthetic regression tasks.
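The two phases can be reproduced numerically in a few lines. The following is a minimal sketch built on a hypothetical one-sample problem in $\mathbb{R}^2$ with $x = (1, 1)$ and $y = 2$. Here $\mathcal{M} = \{w : w_1 + w_2 = 2\}$, and its minimum-norm point is $(1, 1)$. The sketch uses weight decay $\lambda = 10^{-3}$ and a forward-Euler discretisation of the flow with step $0.1$. None of these values come from the paper, and the helper names `residual`, `train_loss` and `gd_weight_decay` are our own.

```python
import math

# Hypothetical toy problem (an assumption, not the paper's setup): one training
# sample in R^2 with F(w) = 0.5 * (<x, w> - y)^2, x = (1, 1), y = 2.
# Interpolation manifold M = {w : w1 + w2 = 2}; its minimum-norm point is (1, 1).


def residual(w, x, y):
    """Prediction error <x, w> - y for the single training sample."""
    return sum(xi * wi for xi, wi in zip(x, w)) - y


def train_loss(w, x, y):
    """Unregularised training loss F(w)."""
    return 0.5 * residual(w, x, y) ** 2


def gd_weight_decay(w0, x, y, lam, eta, checkpoints):
    """Forward-Euler discretisation of dw/dt = -grad F(w) - lam * w.

    Returns {t: (w, F(w), ||w||)} for each requested time t in `checkpoints`.
    """
    record_at = {round(t / eta): t for t in checkpoints}  # step index -> time
    w = list(w0)
    history = {}
    for k in range(max(record_at) + 1):
        if k in record_at:
            norm = math.sqrt(sum(wi * wi for wi in w))
            history[record_at[k]] = (list(w), train_loss(w, x, y), norm)
        r = residual(w, x, y)
        # grad F(w) = r * x; weight decay adds lam * w
        w = [wi - eta * (r * xi + lam * wi) for wi, xi in zip(w, x)]
    return history


if __name__ == "__main__":
    lam = 1e-3
    hist = gd_weight_decay([4.0, 0.0], [1.0, 1.0], 2.0, lam, 0.1, [0, 10, 1000, 10000])
    for t, (w, f, n) in sorted(hist.items()):
        print(f"t={t:>6}  w=({w[0]:+.3f}, {w[1]:+.3f})  F={f:.1e}  ||w||={n:.3f}")
```

The run starts from $w_0 = (4, 0)$. By $t = 10$ it should be close to $(3, -1)$, the projection of $w_0$ onto $\mathcal{M}$, with near-zero training loss. It then leaves that point only slowly. Around $t = 1/\lambda$ it is roughly $(1.74, 0.26)$, and by $t = 10/\lambda$ it is close to $(1, 1)$, while $\|w\|$ decreases monotonically. With $\lambda > 0$ the training loss is not exactly zero; it stays of order $\lambda^2$.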

Paper Structure

This paper contains 48 sections, 21 theorems, 96 equations, 4 figures.

Key Result

Theorem 1

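As the weight decay parameter $\lambda$ is taken to $0$, the trajectory $w^\lambda(t)$ can be seen as a composition of two coupled dynamics:

  • Fast dynamics: on timescales $t = O(1)$, $w^\lambda(t)$ follows the unregularised gradient flow $\dot w = -\nabla F(w)$ and converges to a manifold $\mathcal{M}$ of critical points of $F$.
  • Slow dynamics: at times $t$ of order $1/\lambda$, the regularisation term becomes dominant and $w^\lambda(t)$ drifts along $\mathcal{M}$, following a Riemannian gradient flow that minimises the $\ell_2$-norm of the parameters.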

Figures (4)

  • Figure 1: Gradient flow with small weight decay $\lambda$. (Left) A typical example of grokking: the training loss rapidly drops to zero, while the test loss plateaus for a long period before eventually decreasing, coinciding with a steady drop in the $\ell_2$-norm of the weights. (Right) Schematic illustration in parameter space $\mathbb{R}^d$ of the optimisation behaviour described in Theorem 1. The trajectory $w^\lambda(t)$ initially follows the unregularised gradient flow and converges to a manifold of critical points of $F$ (fast dynamics). At time $t \approx 1/\lambda$, the regularisation term becomes dominant and induces a slow drift along this manifold toward a lower $\ell_2$-norm solution (slow dynamics).
  • Figure 2: Low-rank matrix completion. (Left): Grokking phenomenon: the training loss drops quickly to zero, while the test loss remains high for an extended period before eventually improving, coinciding with a decrease in the norm of the weights $\Vert w \Vert^2 = \Vert U \Vert_F^2 + \Vert V \Vert_F^2$. (Right): Singular values of $UV^\top$ over time; each line corresponds to the $i$-th singular value of $UV^\top$. The singular values rapidly converge to large positive values at time $t \approx 1$. However, once grokking starts around time $t \approx 10^2$, all but three begin to decay towards zero, while the remaining three approach the true singular values $\sigma^\star_1$, $\sigma^\star_2$, and $\sigma^\star_3$.
  • Figure 3: Two-layer ReLU network trained with gradient descent and small weight decay. (Left): Grokking phenomenon: the training loss drops quickly to zero, while the test loss remains high for an extended period before eventually improving—coinciding with a slow, steady decrease in the weight norm. (Right): Snapshots of the network's prediction function at various training times. The ground truth teacher function (a sum of three ReLUs) is shown in dotted light blue, and the training samples are shown as black crosses.
  • Figure 4: Gradient flow with small weight decay $\lambda$ on a two-layer diagonal linear network trained on a regression dataset. (Left): Empirical observation of the grokking behaviour. The training loss rapidly drops to zero, while the test loss remains flat for an extended period before eventually decreasing. This transition coincides with a slow but steady decrease in the $\ell_2$-norm of the weights. (Three plots on the right): Visualisation of the model predictions throughout training. The dotted light blue curve represents the teacher function, and the crosses indicate the training data. Snapshots of the model’s prediction function at various training times (shown in increasing colour intensity) illustrate how generalisation changes before and after the transition at $t \approx 1 / \lambda$.
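A toy version of the diagonal-network experiment shows what the slow norm minimisation selects. The following is a minimal sketch under assumptions of our own, and the paper's exact parametrisation and data for Figure 4 may differ. We assume the predictor is $\beta = u \odot v$ and that there is a single hypothetical sample $x = (1, 2)$ with target $y = 2$. The initialisation is balanced ($u_0 = v_0 = (1, 1)$), and the flow is discretised with forward Euler. Because $u_i^2 + v_i^2 \ge 2|u_i v_i|$, with equality when $|u_i| = |v_i|$, minimising $\|u\|^2 + \|v\|^2$ over interpolators amounts to minimising $\|\beta\|_1$. The slow phase should therefore favour the sparse interpolator $(0, 1)$ over the minimum-$\ell_2$ one, $(0.4, 0.8)$. The function name `diag_net_gd` is hypothetical.

```python
# Hypothetical toy (an assumption, not the paper's exact setup): a two-layer
# diagonal linear network beta = u * v (elementwise), one training sample
# x = (1, 2), target y = 2, loss F(u, v) = 0.5 * (<x, beta> - y)^2.
# Interpolators satisfy beta1 + 2 * beta2 = 2; the l1-minimal one is (0, 1),
# the l2-minimal one is (0.4, 0.8).


def diag_net_gd(u0, v0, x, y, lam, eta, checkpoints):
    """Forward Euler on gradient flow of F(u, v) + lam / 2 * (||u||^2 + ||v||^2).

    Returns {t: (beta, F, ||u||^2 + ||v||^2)} for each requested time t.
    """
    record_at = {round(t / eta): t for t in checkpoints}  # step index -> time
    u, v = list(u0), list(v0)
    history = {}
    for k in range(max(record_at) + 1):
        beta = [ui * vi for ui, vi in zip(u, v)]
        r = sum(xi * bi for xi, bi in zip(x, beta)) - y
        if k in record_at:
            sq_norm = sum(ui * ui for ui in u) + sum(vi * vi for vi in v)
            history[record_at[k]] = (beta, 0.5 * r * r, sq_norm)
        # dF/du_i = r * x_i * v_i and dF/dv_i = r * x_i * u_i; both lists are
        # built from the old (u, v), so the update is simultaneous.
        u, v = (
            [ui - eta * (r * xi * vi + lam * ui) for ui, vi, xi in zip(u, v, x)],
            [vi - eta * (r * xi * ui + lam * vi) for ui, vi, xi in zip(u, v, x)],
        )
    return history


if __name__ == "__main__":
    hist = diag_net_gd([1.0, 1.0], [1.0, 1.0], [1.0, 2.0], 2.0, 1e-3, 0.1, [10, 1000, 10000])
    for t, (beta, f, s) in sorted(hist.items()):
        print(f"t={t:>6}  beta=({beta[0]:.4f}, {beta[1]:.4f})  F={f:.1e}  ||u||^2+||v||^2={s:.3f}")
```

With $\lambda = 10^{-3}$ and step $0.1$, the network already interpolates at $t = 10$, with a dense $\beta$ whose two coordinates are both clearly nonzero. At that point $\|u\|^2 + \|v\|^2 \approx 2.8$. By $t = 10/\lambda$, $\beta_1$ has decayed roughly like $e^{-\lambda t}$ and $\beta \approx (0, 1)$, while $\|u\|^2 + \|v\|^2$ has dropped to about $2 = 2\|\beta\|_1$. In this toy, the slow phase therefore changes which interpolator is selected rather than the training loss.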

Theorems & Definitions (36)

  • Theorem 1: Main result, informal
  • Definition: the manifold $\mathcal{M}$
  • Proposition 1
  • Lemma 1
  • Proposition 2
  • Proposition 3
  • Proposition 4
  • Proposition 4
  • Proof
  • Lemma 2
  • ...and 26 more