Why Do You Grok? A Theoretical Analysis of Grokking Modular Addition

Mohamad Amin Mohamadi, Zhiyuan Li, Lei Wu, Danica J. Sutherland

TL;DR

The paper investigates why grokking occurs in modular addition by proving a fundamental kernel-to-rich regime transition during gradient descent. It shows that permutation-equivariant kernel methods cannot generalize unless the training set covers a constant fraction of all possible data points, while regularized two-layer quadratic networks can generalize from far fewer samples once they leave the kernel regime. The authors establish both lower bounds and constructive upper bounds: lower bounds for kernel-based generalization in regression and classification, and rich-regime generalization guarantees with small $\ell_\infty$ norm (and margin-based PAC-Bayes bounds) that enable generalization with $\tilde{\mathcal{O}}(p^2)$ data for regression and $\tilde{\mathcal{O}}(p^{5/3})$ for classification. They provide theoretical results, a general framework for population loss lower bounds, and empirical evidence including Transformer-like models, supporting the grokking narrative as a delayed transition from kernel-dominated behavior to feature-learning dynamics. This work deepens the understanding of grokking and suggests practical regularization-based mechanisms to induce early generalization in overparameterized models.
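
As a rough worked comparison of the two regimes, the following sketch ignores constants and logarithmic factors. It assumes the classification task ranges over the $p^2$ input pairs $(a, b)$ (an assumption about the setup), and takes the $p^3$ total for regression from the Figure 3 caption:

$$
\underbrace{\frac{p^{5/3}}{p^{2}} = p^{-1/3}}_{\text{classification}}
\qquad\qquad
\underbrace{\frac{p^{2}}{p^{3}} = p^{-1}}_{\text{regression}}
$$

Both fractions vanish as $p$ grows, whereas the kernel-regime lower bounds require a constant fraction of all data points.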

Abstract

We present a theoretical explanation of the "grokking" phenomenon, where a model generalizes long after overfitting, for the originally-studied problem of modular addition. First, we show that early in gradient descent, when the "kernel regime" approximately holds, no permutation-equivariant model can achieve small population error on modular addition unless it sees at least a constant fraction of all possible data points. Eventually, however, models escape the kernel regime. We show that two-layer quadratic networks that achieve zero training loss with bounded $\ell_{\infty}$ norm generalize well with substantially fewer training points, and further show such networks exist and can be found by gradient descent with small $\ell_{\infty}$ regularization. We further provide empirical evidence that these networks, as well as simple Transformers, leave the kernel regime only after initially overfitting. Taken together, our results strongly support the case for grokking as a consequence of the transition from kernel-like behavior to limiting behavior of gradient descent on deep networks.
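
To make the setting concrete, here is a minimal NumPy sketch of the modular addition task and a two-layer network with quadratic activation. The input encoding (concatenated one-hot vectors for $a$ and $b$), the parameter names `W` and `V`, the initialization scale, and the helper names are illustrative assumptions, not details confirmed by the abstract. The width $h = 4p$ follows the Figure 3 caption.

```python
import numpy as np

def modular_addition_dataset(p):
    """Enumerate all p^2 pairs (a, b) with label (a + b) mod p."""
    a, b = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    a, b = a.ravel(), b.ravel()
    return a, b, (a + b) % p

def one_hot_pairs(a, b, p):
    """Assumed encoding: concatenate one-hot(a) and one-hot(b) into R^{2p}."""
    x = np.zeros((a.size, 2 * p))
    x[np.arange(a.size), a] = 1.0
    x[np.arange(a.size), p + b] = 1.0
    return x

def quadratic_net(x, W, V):
    """Two-layer net with quadratic activation: f(x) = V (W x)^2 (elementwise square)."""
    return ((x @ W.T) ** 2) @ V.T

def linf_norm(W, V):
    """Largest absolute parameter value, i.e. the l_inf norm of all weights."""
    return max(np.abs(W).max(), np.abs(V).max())

p = 7
h = 4 * p                      # width h = 4p, as in the Figure 3 caption
rng = np.random.default_rng(0)

a, b, y = modular_addition_dataset(p)
x = one_hot_pairs(a, b, p)
W = rng.normal(scale=0.1, size=(h, 2 * p))   # hypothetical init scale
V = rng.normal(scale=0.1, size=(p, h))
logits = quadratic_net(x, W, V)              # one row of p outputs per pair
```

Training such a network with a small penalty on `linf_norm(W, V)` would correspond to the $\ell_\infty$ regularization the abstract mentions; the optimizer itself is left out of this sketch.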

Paper Structure (36 sections, 58 theorems, 152 equations, 5 figures)

Key Result

Theorem 2.4

For any p.s.d. kernel $K:{\mathcal{X}}\times{\mathcal{X}}\to \mathbb{R}$ and transformation group ${\mathcal{G}}_{{\mathcal{X}}}$, kernel regression (Definition 2.3) with respect to kernel $K$ is ${\mathcal{G}}_{{\mathcal{X}}}$-equivariant if and only if kernel $K$ is equivariant to ${\mathcal{G}}_{{\mathcal{X}}}$, i.e.

Figures (5)

  • Figure 1: Empirical investigation into grokking modular addition on two-layer networks in the classification task with cross-entropy loss. Left: The change in the empirical NTK ($\lVert \hat{\Theta}_t - \hat{\Theta}_0 \rVert_F$) is negligible before fitting the training data. The NTK changes drastically after overfitting, implying that the delayed generalization might be caused by a delayed transition from the kernel to the rich regime. Middle: Reducing the initialization scale can mitigate grokking, to the point of completely eliminating the gap between train and test curves. $\alpha$ denotes the scale multiplied by $\theta_0$, the initial weights according to default PyTorch initialization (He et al., 2015). The dashed lines indicate train set statistics, and the solid lines correspond to the test set. Right: Empirical evaluations support a sample complexity of $\tilde{\mathcal{O}}(p^{5/3})$ on the classification task with cross-entropy loss. More details are given in the classification section (sec:classification).
  • Figure 2: Empirical evidence for kernel regime in early training. Left: Train (dashed) and test (solid) accuracy while training in the regression setting with various initialization scales $\alpha$ and a fixed $p=47$. Shrinking the scale of initialization can mitigate grokking in the regression task, eventually eliminating the gap between train and test accuracies (at the cost of slower improvement in each). Right: The eNTK continues to change significantly after overfitting.
  • Figure 3: Empirical verification of our theoretical explanation for generalization. We train a network of width $h=4p$ with gradient descent on $\ell_2$ loss and $\ell_\infty$ regularization of strength $10^{-4}$ on $2 \times p^{2.25}$ training samples (out of $p^3$) in the regression setting. Left: Generalization happens when the number of samples is more than $\Omega(p^2)$, as predicted by the generalization bound (th:generalization_bound_3d). The dashed and solid lines indicate train and test set statistics respectively. Right: The $\ell_\infty$ norm of parameters after grokking remains the same for different problem dimensions $p$, as predicted by
  • Figure 4: Empirical investigation of grokking in the classification setting with tiny $\ell_\infty$ regularization on different problem dimensions $p$. Networks are trained with normalized gradient descent on cross-entropy loss and an $\ell_\infty$ regularization of $10^{-20}$ strength. The dashed lines indicate train set statistics, and the solid lines correspond to the test set.
  • Figure 5: Grokking in transformers happens after a delayed transition from the kernel to the rich regime. A one-layer transformer is trained with gradient descent using cross-entropy loss and a tiny $\ell_\infty$ regularization of $10^{-20}$ strength on $2 \times p^{5/3}$ training samples from the modular addition problem with various values of $p$. The change in the eNTK up to the point of fitting the training set is negligible. The eNTK changes drastically only after fitting the whole training set, implying minimal feature learning until past overfitting. The dashed lines in the middle and left panels indicate train set statistics, and the solid lines correspond to the test set.
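
The quantity $\lVert \hat{\Theta}_t - \hat{\Theta}_0 \rVert_F$ tracked in Figures 1, 2 and 5 can be sketched as follows for a two-layer quadratic network. This is an illustrative NumPy sketch under assumptions: the network form $f(x) = V (W x)^2$, the closed-form Jacobian, random stand-in inputs, and all function names are hypothetical and not taken from the paper's code, which may compute the eNTK differently (for example with automatic differentiation).

```python
import numpy as np

def jacobian(x, W, V):
    """Jacobian of f(x) = V (W x)^2 w.r.t. (W, V); shape (n * k, h * d + k * h)."""
    n, d = x.shape
    k, h = V.shape
    z = x @ W.T                                          # pre-activations, (n, h)
    # d f_k / d W[j, :] = 2 * V[k, j] * z[:, j] * x
    dW = 2.0 * np.einsum("kj,nj,nd->nkjd", V, z, x)      # (n, k, h, d)
    # d f_k / d V[l, j] = 1[k == l] * z[:, j]^2
    dV = np.einsum("kl,nj->nklj", np.eye(k), z ** 2)     # (n, k, k, h)
    return np.concatenate(
        [dW.reshape(n * k, h * d), dV.reshape(n * k, k * h)], axis=1
    )

def entk(x, W, V):
    """Empirical NTK on the inputs x: Theta = J J^T over all (input, output) pairs."""
    J = jacobian(x, W, V)
    return J @ J.T

def entk_change(x, params_t, params_0):
    """Frobenius norm of the difference between the eNTK at step t and at step 0."""
    return np.linalg.norm(entk(x, *params_t) - entk(x, *params_0), ord="fro")

rng = np.random.default_rng(0)
n, d, h, k = 5, 6, 4, 3                    # tiny sizes, for illustration only
x = rng.normal(size=(n, d))                # stand-in inputs (not modular addition data)
W0, V0 = rng.normal(size=(h, d)), rng.normal(size=(k, h))
Wt = W0 + 0.1 * rng.normal(size=(h, d))    # stand-in for weights after some training
Vt = V0 + 0.1 * rng.normal(size=(k, h))
change = entk_change(x, (Wt, Vt), (W0, V0))
```

In the kernel regime this value stays close to zero over training steps. A sharp increase after the training set is fit is the signature the captions above describe.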

Theorems & Definitions (111)

  • Definition 2.1
  • Definition 2.2: Equivariant Algorithms
  • Definition 2.3: Kernel Methods
  • Theorem 2.4
  • Proof of Theorem 2.4
  • Theorem 3.1
  • Proof sketch of Theorem 3.1
  • Definition 3.2
  • Theorem 3.3: Permutation Equivariance in Regression
  • Theorem 3.4: Lower Bound
  • ...and 101 more