Dichotomy of Early and Late Phase Implicit Biases Can Provably Induce Grokking

Kaifeng Lyu, Jikai Jin, Zhiyuan Li, Simon S. Du, Jason D. Lee, Wei Hu

TL;DR

The paper investigates grokking, the abrupt generalization improvement observed when training neural nets on tasks such as modular addition. It develops a theory for homogeneous networks trained with large initialization and small weight decay, revealing a two-phase dynamic: an early kernel regime governed by the Neural Tangent Kernel (NTK) leading to a kernel predictor, and a late regime where gradient flow converges to max-margin/min-norm predictors, producing a sharp generalization transition. The authors prove precise time-scales and convergence directions for both classification and regression, and provide concrete instantiations (diagonal linear nets for sparse linear classification and an over-parameterized matrix completion model) demonstrating grokking and misgrokking. Their results illuminate how implicit biases at different training phases can generate dramatic changes in test performance and offer guidance for understanding or mitigating grokking in practice. Key insights connect to NTK theory, margin maximization, and low-rank matrix recovery, with implications for the design of initialization, regularization, and training schedules in deep networks.

Abstract

Recent work by Power et al. (2022) highlighted a surprising "grokking" phenomenon in learning arithmetic tasks: a neural net first "memorizes" the training set, resulting in perfect training accuracy but near-random test accuracy, and after training for sufficiently longer, it suddenly transitions to perfect test accuracy. This paper studies the grokking phenomenon in theoretical setups and shows that it can be induced by a dichotomy of early and late phase implicit biases. Specifically, when training homogeneous neural nets with large initialization and small weight decay on both classification and regression tasks, we prove that the training process gets trapped at a solution corresponding to a kernel predictor for a long time, and then a very sharp transition to min-norm/max-margin predictors occurs, leading to a dramatic change in test accuracy.
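The training setup described in the abstract (a homogeneous model, large initialization, small weight decay) can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' code. It assumes the two-layer diagonal linear net parameterization $f(\theta; x) = \sum_k (u_k^2 - v_k^2) x_k$, the logistic loss, and plain gradient descent as a discretization of gradient flow. The names `init_params`, `predict`, and `gd_step` are hypothetical.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def predict(u, v, x):
    """Diagonal linear net (assumed form): f(theta; x) = sum_k (u_k^2 - v_k^2) * x_k.

    The output is 2-homogeneous in theta = (u, v): scaling theta by s scales f by s^2.
    """
    return sum((uk * uk - vk * vk) * xk for uk, vk, xk in zip(u, v, x))

def init_params(dim, alpha):
    """Large initialization: every coordinate equals alpha, so f = 0 at step 0."""
    return [alpha] * dim, [alpha] * dim

def gd_step(u, v, data, lr, weight_decay):
    """One gradient-descent step on mean logistic loss + (weight_decay / 2) * ||theta||^2."""
    n = len(data)
    # Gradient of the weight-decay term.
    grad_u = [weight_decay * uk for uk in u]
    grad_v = [weight_decay * vk for vk in v]
    for x, y in data:
        # Derivative of log(1 + exp(-y * f)) with respect to f, averaged over n samples.
        dloss_df = -y * sigmoid(-y * predict(u, v, x)) / n
        for k, xk in enumerate(x):
            grad_u[k] += dloss_df * 2.0 * u[k] * xk   # df/du_k =  2 * u_k * x_k
            grad_v[k] -= dloss_df * 2.0 * v[k] * xk   # df/dv_k = -2 * v_k * x_k
    new_u = [uk - lr * g for uk, g in zip(u, grad_u)]
    new_v = [vk - lr * g for vk, g in zip(v, grad_v)]
    return new_u, new_v
```

To probe the two phases, one would call `init_params` with a large `alpha`, iterate `gd_step` with a small `weight_decay`, and track training and test accuracy over many steps. The step size and step count needed to see a transition depend on `alpha` and `weight_decay` and are not specified here.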

Paper Structure (28 sections, 36 theorems, 134 equations, 8 figures)

Key Result

Theorem 3.4

For all constants $c \in (0, 1)$, letting $T^{-}_{\mathrm{c}}(\alpha) := \frac{1-c}{\lambda} \log \alpha$, it holds that [...], where $Z(\alpha) := \frac{1}{\gamma_{\mathrm{ntk}}} \log \frac{\alpha^c}{\lambda}$ is a normalizing factor.
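The two quantities defined in the theorem statement are simple to evaluate numerically, which helps show how the time scale $T^{-}_{\mathrm{c}}(\alpha)$ grows with the initialization scale $\alpha$ and shrinks with the weight decay $\lambda$. The sketch below only transcribes the two definitions. The function names are hypothetical, and $\gamma_{\mathrm{ntk}}$ is treated as a given positive number because this extract does not define it.

```python
import math

def kernel_phase_time(alpha, weight_decay, c):
    """T_c^-(alpha) = (1 - c) / lambda * log(alpha), for a constant c in (0, 1)."""
    if not 0.0 < c < 1.0:
        raise ValueError("c must lie strictly between 0 and 1")
    return (1.0 - c) / weight_decay * math.log(alpha)

def normalizing_factor(alpha, weight_decay, c, gamma_ntk):
    """Z(alpha) = (1 / gamma_ntk) * log(alpha^c / lambda).

    Written as c * log(alpha) - log(lambda) to avoid overflow in alpha**c.
    gamma_ntk is assumed to be a positive constant supplied by the caller.
    """
    return (c * math.log(alpha) - math.log(weight_decay)) / gamma_ntk
```

Because $T^{-}_{\mathrm{c}}(\alpha)$ depends on $\alpha$ only through $\log \alpha$, doubling $\log \alpha$ doubles the time scale, while halving $\lambda$ has the same effect.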

Figures (8)

  • Figure 1: Grokking
  • Figure 2: "Misgrokking"
  • Figure 4: He init, $\lambda = 10^{-4}$
  • Figure 5: Effect of Initialization Scale
  • Figure 6: Effect of Weight Decay
  • ...and 3 more figures

Theorems & Definitions (38)

  • Theorem 3.4
  • Theorem 3.5
  • Corollary 3.6: Diagonal linear nets, kernel regime
  • Corollary 3.7: Diagonal linear nets, rich regime
  • Theorem 3.9
  • Theorem 3.10
  • Corollary 3.11: Matrix completion, kernel regime
  • Theorem 3.12: Matrix completion, rich regime
  • Remark 3.13
  • Theorem B.1
  • ...and 28 more