Beyond Progress Measures: Theoretical Insights into the Mechanism of Grokking
Zihan Gu, Ruoyu Chen, Hua Zhang, Yue Hu, Xiaochun Cao
TL;DR
The paper argues that delayed generalization (grokking) is governed jointly by the embedding uniformity that weight decay induces and by the distribution of the training data, across architectures including Transformers and ResNet-18. The analysis has two parts: a formal prime-field task setup with an operator-based Transformer, and proofs that uniform embeddings are necessary and sufficient for the weight-decay term to reach a local minimum. The paper also introduces the main embedding diff (MED) as a concise progress measure and demonstrates grokking on a ResNet-18 task. Its key contributions are a unified theoretical framework linking weight-decay-driven uniformity to test accuracy, a formalization of dataset-driven upper bounds on grokking, and empirical validation through MED monitoring and a dedicated ResNet-18 grokking dataset. In practice, the framework offers a principled lens for designing datasets and architectures to manage generalization dynamics, and MED serves as a usable diagnostic for tracing test-loss trends during training.
Abstract
Grokking, the abrupt improvement in test accuracy after extended overfitting, offers valuable insights into the mechanisms of model generalization. Existing research based on progress measures implies that understanding grokking comes down to understanding the optimization dynamics when the loss function is dominated solely by the weight decay term. However, we find that this optimization merely leads to token uniformity, which is not a sufficient condition for grokking. In this work, we investigate the grokking mechanism underlying the Transformer in the task of prime number operations. Based on theoretical analysis and experimental validation, we present the following insights: (i) The weight decay term encourages uniformity across all tokens in the embedding space when it is minimized. (ii) The occurrence of grokking is jointly determined by the uniformity of the embedding space and the distribution of the training dataset. Building on these insights, we provide a unified perspective for understanding various previously proposed progress measures and introduce a novel, concise, and effective progress measure that traces the changes in test loss more accurately. Finally, to demonstrate the versatility of our theoretical framework, we design a dedicated dataset to validate our theory on ResNet-18, successfully showcasing the occurrence of grokking. The code is released at https://github.com/Qihuai27/Grokking-Insight.
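To make the role of the training-data distribution concrete, the following is a minimal sketch of how a prime-field task might be split into training and test sets. It is not taken from the released code: the operation (modular addition), the prime `p = 97`, the helper name `make_modular_dataset`, and the `train_fraction` parameter are all assumptions chosen for illustration.

```python
import random
from itertools import product


def make_modular_dataset(p=97, train_fraction=0.3, seed=0):
    """Build all pairs (a, b) over Z_p with label (a + b) mod p, then split.

    Hypothetical helper: the operation and the split rule are assumptions,
    not the paper's exact setup.
    """
    # Enumerate the full table of p * p input pairs and their labels.
    examples = [((a, b), (a + b) % p) for a, b in product(range(p), repeat=2)]

    # Shuffle with a fixed seed so the split is reproducible.
    rng = random.Random(seed)
    rng.shuffle(examples)

    # The first train_fraction of the shuffled table is the training set;
    # everything else is held out for testing.
    n_train = int(train_fraction * len(examples))
    return examples[:n_train], examples[n_train:]


train, test = make_modular_dataset(p=97, train_fraction=0.3)
```

Varying `train_fraction` changes which pairs each token appears in during training, which is one simple way to probe how the data distribution interacts with the generalization dynamics described above.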
