Beyond Progress Measures: Theoretical Insights into the Mechanism of Grokking

Zihan Gu, Ruoyu Chen, Hua Zhang, Yue Hu, Xiaochun Cao

TL;DR

The paper argues that embedding uniformity induced by weight decay, together with the training-data distribution, jointly governs delayed generalization (grokking) across architectures, including Transformers and ResNet. Its analysis has two parts: a formal prime-field task setup with an operator-based Transformer, and proofs that uniform embeddings are necessary and sufficient for the weight-decay term to reach a local minimum. It also introduces the main embedding diff (MED) as a concise progress measure and demonstrates grokking on a ResNet-18 task. The key contributions are a unified theoretical framework linking weight-decay-driven uniformity to test accuracy, a formalization of dataset-driven upper bounds on grokking, and empirical validation via MED monitoring and a dedicated ResNet-18 grokking dataset. In practice, the work offers a principled lens for designing datasets and architectures to manage generalization dynamics, and a usable diagnostic (MED) for tracing test-loss trends during training.

Abstract

Grokking, the abrupt improvement in test accuracy after extended overfitting, offers valuable insights into the mechanisms of model generalization. Existing research based on progress measures implies that grokking relies on understanding the optimization dynamics when the loss function is dominated solely by the weight decay term. However, we find that this optimization merely leads to token uniformity, which is not a sufficient condition for grokking. In this work, we investigate the grokking mechanism underlying the Transformer in the task of prime number operations. Based on theoretical analysis and experimental validation, we present the following insights: (i) The weight decay term encourages uniformity across all tokens in the embedding space when it is minimized. (ii) The occurrence of grokking is jointly determined by the uniformity of the embedding space and the distribution of the training dataset. Building on these insights, we provide a unified perspective for understanding various previously proposed progress measures and introduce a novel, concise, and effective progress measure that traces the changes in test loss more accurately. Finally, to demonstrate the versatility of our theoretical framework, we design a dedicated dataset to validate our theory on ResNet-18, successfully demonstrating the occurrence of grokking. The code is released at https://github.com/Qihuai27/Grokking-Insight.

Paper Structure

This paper contains 26 sections, 11 theorems, 73 equations, 20 figures, 2 tables.

Key Result

Lemma 3.2

Let $(\bar{\phi(i)}, \bar{\phi(j)}, \bar{\phi(\mathsf{cls})}) = \mathsf{MLP} \circ \mathsf{SelfAtt}(\phi(i), \phi(j), \phi(\mathsf{cls}))$ and $\bar{c_{i,j}} = l(\bar{\phi(\mathsf{cls})})$. If the model can give the correct output, there exists a set of vectors $\{c_i\}$, $i = 0, 1, \dots, p-1$, and the relevant

Figures (20)

  • Figure 1: Grokking arises from two factors: the uniformity of embeddings under weight decay and the dataset distribution. Embedding uniformity results from parameter updates that minimize the weight decay term, so grokking can be tracked by monitoring this uniformity. The distribution of the training set determines the upper limit of test accuracy. These insights apply to diverse models, including DNNs, Transformers, and ResNet.
  • Figure 2: Impact of Using More Transformer Layers. As the number of Transformer layers increases, the test accuracy fluctuates violently, while the MED keeps decreasing.
  • Figure 3: Designed Training Set Distribution. The first panel shows a randomly selected training set, the second a training set with a square region removed, and the third a training set with a strip region removed.
  • Figure 4: Accuracy Curves When $f(i,j) = i^2 + ij + j^2$. The upper limit of accuracy increases as $\text{frac}$ increases. However, the number of test-set samples that are output correctly does not increase.
  • Figure 5: Performance of MED under different prime numbers. To amplify the changes in the MED, we omitted the $\frac{1}{p}$ coefficient in the definition of MED, meaning that we did not take the average here. In all four groups of experiments, the MED and the test loss show consistent trends.
  • ...and 15 more figures
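
The three training-set designs described in the Figure 3 caption can be sketched roughly as follows, using the function $f(i,j) = i^2 + ij + j^2$ from Figure 4 reduced modulo a prime $p$. This is only an illustrative sketch, not the authors' implementation: the helper names (`target`, `random_split`, `square_hole_split`, `strip_hole_split`), the region bounds `lo` and `hi`, and the choice to orient the strip along the first operand are all assumptions made here.

```python
import random


def target(i, j, p):
    """Label for the pair (i, j): f(i, j) = i^2 + ij + j^2, reduced mod p."""
    return (i * i + i * j + j * j) % p


def all_pairs(p):
    """Every ordered pair (i, j) with 0 <= i, j < p."""
    return [(i, j) for i in range(p) for j in range(p)]


def random_split(p, frac, seed=0):
    """Randomly selected training set containing a fraction `frac` of all pairs."""
    pairs = all_pairs(p)
    random.Random(seed).shuffle(pairs)
    n_train = int(frac * len(pairs))
    return pairs[:n_train], pairs[n_train:]


def square_hole_split(p, lo, hi):
    """Remove the square lo <= i, j < hi from training (hypothetical bounds)."""
    train, test = [], []
    for i, j in all_pairs(p):
        in_hole = lo <= i < hi and lo <= j < hi
        (test if in_hole else train).append((i, j))
    return train, test


def strip_hole_split(p, lo, hi):
    """Remove the strip lo <= i < hi, for every j, from training (assumed orientation)."""
    train, test = [], []
    for i, j in all_pairs(p):
        (test if lo <= i < hi else train).append((i, j))
    return train, test


if __name__ == "__main__":
    p = 7  # small prime, chosen only to keep the example readable
    for name, (train, test) in {
        "random": random_split(p, 0.5),
        "square": square_hole_split(p, 2, 4),
        "strip": strip_hole_split(p, 2, 4),
    }.items():
        print(name, len(train), len(test))
```

In this sketch the strip split removes every training pair whose first operand lies in the held-out range, whereas the square split still leaves each token visible in both positions. That difference is one way a training-set distribution could cap achievable test accuracy, though the regions actually used in the paper may differ.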

Theorems & Definitions (37)

  • Definition 3.1
  • Lemma 3.2
  • Remark 3.3
  • Lemma 3.4
  • Theorem 3.5
  • Remark 3.6
  • Remark 3.7
  • Theorem 3.8
  • Remark 3.9
  • Theorem 3.10
  • ...and 27 more