Grokking as the Transition from Lazy to Rich Training Dynamics

Tanishq Kumar, Blake Bordelon, Samuel J. Gershman, Cengiz Pehlevan

TL;DR

This work argues that grokking, the delay between the drop in train loss and the later drop in test loss, stems from a transition from lazy (kernel-like) training to rich feature learning. Using a minimal two-layer perceptron for polynomial regression, the authors show that grokking can be driven by the output scale $\alpha$ and the initial kernel-task alignment $\epsilon$: a larger $\alpha$ or a less aligned task lengthens the delay, while feature learning still enables eventual generalization. They decompose generalization error into variance, misalignment, and linear-power terms, and connect these to the NTK spectrum and a related kernel-alignment measure (CKA). The results extend beyond the toy model to MNIST, one-layer transformers, and student-teacher networks, and the authors argue that loss curves are a more reliable diagnostic than accuracy curves. Overall, the work provides a unified lazy-to-rich framework for grokking, showing that weight decay is neither necessary nor sufficient for grokking and that data regime and alignment crucially shape learning dynamics and the timing of generalization.

Abstract

We propose that the grokking phenomenon, where the train loss of a neural network decreases much earlier than its test loss, can arise due to a neural network transitioning from lazy training dynamics to a rich, feature-learning regime. To illustrate this mechanism, we study the simple setting of vanilla gradient descent on a polynomial regression problem with a two-layer neural network, which exhibits grokking without regularization in a way that cannot be explained by existing theories. We identify sufficient statistics for the test loss of such a network, and tracking these over training reveals that grokking arises in this setting when the network first attempts to fit a kernel regression solution with its initial features, followed by late-time feature learning where a generalizing solution is identified after train loss is already low. We find that the key determinants of grokking are the rate of feature learning (which can be controlled precisely by parameters that scale the network output) and the alignment of the initial features with the target function $y(x)$. We argue this delayed generalization arises when (1) the top eigenvectors of the initial neural tangent kernel and the task labels $y(x)$ are misaligned, (2) the dataset is large enough that the network can eventually generalize, but not so large that train loss perfectly tracks test loss at all epochs, and (3) the network begins training in the lazy regime and so does not learn features immediately. We conclude with evidence that this transition from lazy (linear model) to rich training (feature learning) can control grokking in more general settings, such as MNIST, one-layer Transformers, and student-teacher networks.
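
The role of the output scale can be illustrated with a small, self-contained sketch. It is not the authors' code: the centered model $f(x) = \alpha\,(g(W; x) - g(W_0; x))$, the activation $\phi(h) = h + \tfrac{\epsilon}{2}h^2$, the quadratic target, the learning-rate scaling $\eta/\alpha^2$, and every name and hyperparameter below are assumptions chosen for illustration. The sketch trains the same network from the same initialization at several values of $\alpha$ and reports how far the weights move, which is one simple proxy for how much feature learning occurs.

```python
import math
import random


def phi(h, eps):
    # Illustrative activation (assumption): identity plus a small quadratic term.
    return h + 0.5 * eps * h * h


def dphi(h, eps):
    # Derivative of phi with respect to its pre-activation.
    return 1.0 + eps * h


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def raw_output(W, x, eps):
    # Unscaled network g(W; x) = N^{-1/2} * sum_i phi(w_i . x).
    return sum(phi(dot(w, x), eps) for w in W) / math.sqrt(len(W))


def train(alpha, steps=200, eta=0.1, eps=0.5, n_hidden=16, dim=3, n_train=8, seed=0):
    """Full-batch GD on f(x) = alpha * (g(W; x) - g(W0; x)), no regularization."""
    rng = random.Random(seed)  # same seed: identical data and init for every alpha
    X = [[rng.gauss(0, 1) / math.sqrt(dim) for _ in range(dim)] for _ in range(n_train)]
    beta = [rng.gauss(0, 1) for _ in range(dim)]
    y = [0.5 * dot(beta, x) ** 2 for x in X]  # quadratic target (assumption)
    W0 = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_hidden)]
    W = [row[:] for row in W0]
    g0 = [raw_output(W0, x, eps) for x in X]
    lr = eta / alpha ** 2  # keeps the function-space step comparable across alpha

    def residuals():
        return [alpha * (raw_output(W, x, eps) - g) - t for x, g, t in zip(X, g0, y)]

    def loss(res):
        return 0.5 * sum(r * r for r in res) / n_train

    initial_loss = loss(residuals())
    for _ in range(steps):
        res = residuals()  # fixed for the whole step, so updates are full-batch
        for w in W:
            grad = [0.0] * dim
            for x, r in zip(X, res):
                coeff = alpha * r * dphi(dot(w, x), eps) / (math.sqrt(n_hidden) * n_train)
                for k in range(dim):
                    grad[k] += coeff * x[k]
            for k in range(dim):
                w[k] -= lr * grad[k]

    moved = math.sqrt(sum((a - b) ** 2 for w, w0 in zip(W, W0) for a, b in zip(w, w0)))
    norm0 = math.sqrt(sum(b * b for w0 in W0 for b in w0))
    return {"rel_move": moved / norm0, "initial_loss": initial_loss,
            "final_loss": loss(residuals())}


results = {alpha: train(alpha) for alpha in (1.0, 10.0, 100.0)}
for alpha, out in results.items():
    print(f"alpha={alpha:>6}: relative weight movement={out['rel_move']:.5f}, "
          f"train loss {out['initial_loss']:.4f} -> {out['final_loss']:.4f}")
```

Under these assumptions, larger $\alpha$ lets the network reduce train loss while its weights stay close to initialization, which is the lazy regime the abstract refers to. The sketch tracks only train loss and weight movement; reproducing the delayed drop in test loss would require a held-out set and longer training.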

Paper Structure (38 sections, 23 equations, 19 figures)

Figures (19)

  • Figure 1: (a) Claimed parameter dynamics during grokking in a parameter space of $\mathbb{R}^3$ for illustrative purposes. $S(X, w_0)$ is an affine subspace of parameter space reachable by models linearized around $w_0$. (b) Grokking on a polynomial regression task (introduced in the section on polynomial regression) with an MLP, vanilla GD, and zero regularization. Green and blue loss curves in (b) correspond to sketched green and blue parameter dynamics in (a). Horizontal line (black) in (b) is the mean-squared error of the best kernel regression estimate with the NTK at initialization. (c) The fact that parameter weight norm increases (dashed orange) cannot be explained by any existing theories of grokking. Features get aligned (grey) in a way made precise in the section on polynomial regression.
  • Figure 2: (a) and (b) demonstrate accuracy and loss curves for grokking on a modular arithmetic task that shows an increase in parameter weight norm during training. (c) is a sweep over the laziness parameter $\alpha$ that we will introduce and provide a theory for, showcasing how it can continuously control grokking.
  • Figure 3: Lazy training ($\alpha$) and kernel-task misalignment ($\epsilon$) alter the grokking learning curves in distinct ways. (Top) Learning curves that show grokking, and (Bottom) corresponding parameter dynamics during learning. (a) At fixed $\epsilon$, the laziness parameter $\alpha$ controls the timescale of the delay in grokking. At small $\alpha$, the grokking effect disappears as the generalizing features are extracted immediately. At large $\alpha$, the model approaches a linearized model. The final test loss decreases with $\alpha$ as we allow the network to learn more features. (b) The task alignment to the initial kernel, measured by $\epsilon$, determines how much the loss falls when the network initially uses its linearized solution. Smaller $\epsilon$ increases the amount of feature learning during training because the initial kernel does worse on the task, so feature learning is necessary. Thus lower alignment can result in better generalization. (c)-(d) Illustrations of the dynamics at varying $\alpha,\epsilon$. In (d), each plane represents a different affine space spanned by the initial gradients, which are a function of $\epsilon$. (e) Time to grok (time delay between train loss fall and test loss fall) as a function of $\alpha,\epsilon$, showing how lazy, misaligned networks grok the most intensely.
  • Figure 4: (Left) Learning curves for a two-layer MLP that groks on our polynomial regression task (no weight decay, vanilla GD). (Right) Theoretical decomposition showing how the initial rise in test loss comes from the network putting power in the linear component from its initial NTK before beginning to align with the task features around 1-2k epochs, resulting in a delayed fall in the test loss. This delay is precisely grokking.
  • Figure 5: Demonstrating grokking on standard Gaussian data $X \in \mathbb{R}^{D \times P}$ with the one-hidden-layer architecture used for polynomial regression. The label vector $y$ is replaced with the $j$-th eigenvector of the initial NTK (ordered by descending eigenvalue) for $j \in \{1, 70, 100\}$ in (a)-(c), respectively. If the network begins highly aligned, as in (a), the learning curves move together. If the network features are poorly aligned at initialization, as in (c), the network cannot generalize. In the middle, as in (b), the network groks.
  • ...and 14 more figures
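
The Figure 5 setup, in which labels are taken from eigenvectors of the initial NTK, suggests a simple way to quantify kernel-task alignment. The sketch below is a hedged illustration rather than the paper's procedure: it builds the empirical NTK Gram matrix of an assumed one-hidden-layer network with activation $\phi(h) = h + \tfrac{\epsilon}{2}h^2$, finds the top eigenvector by power iteration, and compares the uncentered alignment $y^\top K y / (\lVert K \rVert_F \lVert y \rVert^2)$ for two label choices. The paper's CKA measure may be defined differently (for example with centering), and all names, sizes, and the quadratic task are assumptions.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def ntk_gram(X, W, eps):
    # Empirical NTK of g(W; x) = N^{-1/2} * sum_i phi(w_i . x) with
    # phi(h) = h + (eps / 2) * h^2 (assumed architecture):
    # K(x, x') = (x . x') * (1 / N) * sum_i phi'(w_i . x) * phi'(w_i . x').
    n = len(W)
    D = [[1.0 + eps * dot(w, x) for w in W] for x in X]  # phi' per sample and unit
    return [[dot(xa, xb) * dot(da, db) / n for xb, db in zip(X, D)]
            for xa, da in zip(X, D)]


def matvec(K, v):
    return [dot(row, v) for row in K]


def alignment(K, y):
    # Uncentered kernel-target alignment: y^T K y / (||K||_F * ||y||^2).
    fro = math.sqrt(sum(k * k for row in K for k in row))
    return dot(y, matvec(K, y)) / (fro * dot(y, y))


def top_eigenvector(K, iters=1000, seed=1):
    # Power iteration; adequate for a small symmetric PSD Gram matrix.
    rng = random.Random(seed)
    v = [rng.gauss(0, 1) for _ in K]
    for _ in range(iters):
        v = matvec(K, v)
        norm = math.sqrt(dot(v, v))
        v = [a / norm for a in v]
    return v


rng = random.Random(0)
dim, n_hidden, n_train, eps = 3, 32, 12, 0.5  # illustrative sizes (assumptions)
X = [[rng.gauss(0, 1) / math.sqrt(dim) for _ in range(dim)] for _ in range(n_train)]
W0 = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_hidden)]
beta = [rng.gauss(0, 1) for _ in range(dim)]

K = ntk_gram(X, W0, eps)
y_top = top_eigenvector(K)                     # labels aligned with the kernel
y_task = [0.5 * dot(beta, x) ** 2 for x in X]  # quadratic task labels (assumption)

a_top = alignment(K, y_top)
a_task = alignment(K, y_task)
print(f"alignment with top NTK eigenvector: {a_top:.4f}")
print(f"alignment with quadratic task:      {a_task:.4f}")
```

Because the Gram matrix is positive semidefinite, this alignment lies between 0 and 1 and is maximized by the top eigenvector, so the first value serves as a ceiling for the second. Labels taken from lower-ranked eigenvectors, as in panels (b) and (c) of Figure 5, would score progressively lower on the same measure.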