Dichotomy of Early and Late Phase Implicit Biases Can Provably Induce Grokking
Kaifeng Lyu, Jikai Jin, Zhiyuan Li, Simon S. Du, Jason D. Lee, Wei Hu
TL;DR
The paper investigates grokking, the abrupt improvement in generalization observed when training neural nets on tasks such as modular addition. It develops a theory for homogeneous networks trained with large initialization and small weight decay, revealing a two-phase dynamic: an early kernel regime, governed by the Neural Tangent Kernel (NTK), in which training stalls at a kernel predictor, and a late regime in which gradient flow converges to max-margin/min-norm predictors, producing a sharp generalization transition. The authors prove time scales and convergence directions for both classification and regression, and give concrete instantiations (diagonal linear nets for sparse linear classification and an over-parameterized matrix completion model) that exhibit grokking and misgrokking. The results show how implicit biases at different training phases can cause dramatic changes in test performance, and they offer guidance for understanding or mitigating grokking in practice. The analysis draws on NTK theory, margin maximization, and low-rank matrix recovery, with implications for the choice of initialization, regularization, and training schedules in deep networks.
Abstract
Recent work by Power et al. (2022) highlighted a surprising "grokking" phenomenon in learning arithmetic tasks: a neural net first "memorizes" the training set, resulting in perfect training accuracy but near-random test accuracy, and after sufficiently longer training, it suddenly transitions to perfect test accuracy. This paper studies the grokking phenomenon in theoretical setups and shows that it can be induced by a dichotomy of early and late phase implicit biases. Specifically, when training homogeneous neural nets with large initialization and small weight decay on both classification and regression tasks, we prove that the training process gets trapped at a solution corresponding to a kernel predictor for a long time, and then a very sharp transition to min-norm/max-margin predictors occurs, leading to a dramatic change in test accuracy.
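To make the training setup concrete, the following is a minimal sketch of a diagonal linear net trained on a sparse linear classification task with large initialization and small weight decay. It is an illustration under stated assumptions, not the authors' code: the parameterization w = u*u - v*v, the symmetric initialization u = v = alpha * 1, the logistic loss, and every numeric value (alpha, lam, lr, the data sizes, and the sparse ground truth w_star) are choices made here for the example. Plain gradient descent with a small step size stands in for the gradient flow analyzed in the paper.

```python
import numpy as np

def predict(u, v, X):
    """Diagonal linear net f(theta; x) = <u*u - v*v, x>; 2-homogeneous in theta = (u, v)."""
    return X @ (u * u - v * v)

def loss_and_grads(u, v, X, y, lam):
    """Mean logistic loss plus (lam / 2) * ||theta||^2 weight decay, with analytic gradients."""
    margins = y * predict(u, v, X)
    data_loss = np.mean(np.logaddexp(0.0, -margins))
    # d/dm log(1 + exp(-m)) = -sigmoid(-m), written in an overflow-safe form
    dloss_dmargin = -0.5 * (1.0 - np.tanh(0.5 * margins))
    grad_w = X.T @ (dloss_dmargin * y) / len(y)  # gradient w.r.t. w = u*u - v*v
    grad_u = 2.0 * u * grad_w + lam * u          # chain rule through w, plus weight decay
    grad_v = -2.0 * v * grad_w + lam * v
    reg = 0.5 * lam * (u @ u + v @ v)
    return data_loss + reg, grad_u, grad_v

def train(X, y, alpha=10.0, lam=1e-4, lr=2e-4, steps=200):
    """Gradient descent from the large initialization u = v = alpha * 1 (assumed here)."""
    d = X.shape[1]
    u = alpha * np.ones(d)  # u = v makes the initial output exactly zero
    v = alpha * np.ones(d)
    losses = []
    for _ in range(steps):
        loss, grad_u, grad_v = loss_and_grads(u, v, X, y, lam)
        losses.append(loss)
        u = u - lr * grad_u
        v = v - lr * grad_v
    return u, v, losses

# Hypothetical sparse linear classification data: labels depend on 3 of 50 coordinates.
rng = np.random.default_rng(0)
n, d = 20, 50
X = rng.standard_normal((n, d))
w_star = np.zeros(d)
w_star[:3] = 1.0
y = np.sign(X @ w_star)

u, v, losses = train(X, y)
```

This short run only covers the beginning of training, where the weights stay close to their large initial values. Observing the late-phase transition described in the abstract would require training for far longer, and the sketch makes no claim about when or whether it appears for these particular hyperparameters.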
