
Let Me Grok for You: Accelerating Grokking via Embedding Transfer from a Weaker Model

Zhiwei Xu, Zhiyu Ni, Yixin Wang, Wei Hu

TL;DR

Grokking is a form of delayed generalization in which a model memorizes the training data long before it reaches near-perfect test performance. The authors propose GrokTransfer, a two-stage embedding-transfer method. It first trains a weaker model to learn an informative embedding $E_W$, then initializes the target model's embedding as a product $E_T = AB$ with $A = E_W$ and a trainable $B$. Because $A$ has only as many columns as the weak model's embedding dimension, $E_T$ is effectively constrained to low rank. They prove that, on a high-dimensional XOR task, GrokTransfer lets the target model generalize directly after transfer. They also demonstrate consistent empirical gains for fully-connected networks and Transformers on modular addition, modular multiplication, and parity tasks. The approach reshapes the training dynamics, reducing the computation time and unpredictability associated with grokking without requiring extra data. This suggests practical benefits for accelerating generalization across architectures.
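The factorized embedding described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation. The helper names (`init_factorized_embedding`, `embed`) are hypothetical. The Gaussian initialization of $B$ and its scale are assumptions, and the toy weak embedding is made up. The sketch shows the structural property behind the low-rank claim: $A$ has only $d_W$ columns, so $E_T = AB$ has rank at most $d_W$, and any linear relation among rows of $E_W$ carries over to $E_T$.

```python
import random

def matmul(X, Y):
    """Multiply matrices stored as lists of rows."""
    inner = len(Y)
    assert all(len(row) == inner for row in X), "inner dimensions must match"
    return [[sum(X[i][k] * Y[k][j] for k in range(inner)) for j in range(len(Y[0]))]
            for i in range(len(X))]

def init_factorized_embedding(E_W, d_T, scale=0.02, seed=0):
    """Hypothetical helper: return (A, B) so that the target embedding is A @ B.

    A (V x d_W) starts as a copy of the weak model's embedding E_W.
    B (d_W x d_T) is a small random matrix to be trained (Gaussian init is an assumption).
    """
    rng = random.Random(seed)
    d_W = len(E_W[0])
    A = [list(row) for row in E_W]
    B = [[rng.gauss(0.0, scale) for _ in range(d_T)] for _ in range(d_W)]
    return A, B

def embed(token, A, B):
    """Embed one token id: take row `token` of A and map it through B."""
    return matmul([A[token]], B)[0]

# Toy weak embedding: 3 tokens, d_W = 2; token 2's row is token 0's row plus token 1's row.
E_W = [[1.0, 0.0],
       [0.0, 1.0],
       [1.0, 1.0]]
A, B = init_factorized_embedding(E_W, d_T=4)
E_T = matmul(A, B)  # 3 x 4 target embedding with rank at most d_W = 2
# The row relation from E_W still holds in E_T: E_T[2] == E_T[0] + E_T[1].
```

In a real model, $B$ (and possibly $A$) would be updated by gradient descent. This sketch only covers the initialization.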

Abstract

"Grokking" is a phenomenon where a neural network first memorizes training data and generalizes poorly, but then suddenly transitions to near-perfect generalization after prolonged training. While intriguing, this delayed generalization phenomenon compromises predictability and efficiency. Ideally, models should generalize directly without delay. To this end, this paper proposes GrokTransfer, a simple and principled method for accelerating grokking in training neural networks, based on the key observation that data embedding plays a crucial role in determining whether generalization is delayed. GrokTransfer first trains a smaller, weaker model to reach a nontrivial (but far from optimal) test performance. Then, the learned input embedding from this weaker model is extracted and used to initialize the embedding in the target, stronger model. We rigorously prove that, on a synthetic XOR task where delayed generalization always occurs in normal training, GrokTransfer enables the target model to generalize directly without delay. Moreover, we demonstrate that, across empirical studies of different tasks, GrokTransfer effectively reshapes the training dynamics and eliminates delayed generalization, for both fully-connected neural networks and Transformers.
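One way to make "delayed generalization" concrete is to count the epochs between memorization and generalization. The sketch below is an illustrative assumption, not a metric defined in the paper. The function names, the 0.99 accuracy threshold, and the toy accuracy curves are all made up to show the computation.

```python
def first_epoch_reaching(acc_curve, threshold):
    """Index of the first epoch whose accuracy is >= threshold, or None if never reached."""
    for epoch, acc in enumerate(acc_curve):
        if acc >= threshold:
            return epoch
    return None

def generalization_delay(train_acc, test_acc, threshold=0.99):
    """Hypothetical metric: epochs from memorization (train acc >= threshold)
    to generalization (test acc >= threshold); None if either never happens."""
    t_mem = first_epoch_reaching(train_acc, threshold)
    t_gen = first_epoch_reaching(test_acc, threshold)
    if t_mem is None or t_gen is None:
        return None
    return t_gen - t_mem

# Toy curves (made up): one run with delayed generalization, one that generalizes directly.
train_acc = [0.50, 0.90, 1.00, 1.00, 1.00, 1.00]
test_grokking = [0.10, 0.10, 0.10, 0.10, 0.20, 1.00]
test_direct = [0.10, 0.50, 1.00, 1.00, 1.00, 1.00]

delay_grokking = generalization_delay(train_acc, test_grokking)  # 5 - 2 = 3
delay_direct = generalization_delay(train_acc, test_direct)      # 2 - 2 = 0
```

Under this measure, "generalizing directly without delay" corresponds to a delay near zero.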

Paper Structure

This paper contains 29 sections, 8 theorems, 92 equations, 18 figures, 1 table.

Key Result

Lemma 3.0

For any $f(x) = \sum_{j=1}^3 a_j \phi(w_j^{\top}x)$, where $\phi$ is the ReLU activation function, we have

Figures (18)

  • Figure 1: (a) Overview of the GrokTransfer framework. (b) Training dynamics of a model trained with GrokTransfer versus one trained from scratch. Training from scratch (blue lines) shows a clear phase transition between memorization and generalization. GrokTransfer (red lines) lets the model make continuous progress, significantly reducing the gap between memorization and generalization. See the appendix on experimental details for the setup.
  • Figure 2: FNN training dynamics using different embeddings for the modular addition task ($p=113$). The training dynamics vary significantly across embeddings, and the one-hot and GPT embeddings exhibit a sharp phase transition. See the appendix on experimental details for the setup.
  • Figure 3: Change in the empirical neural tangent kernel (NTK).
  • Figure 4: (a) Training dynamics of a two-layer neural network with hidden width $2048$, where grokking is observed. (b) Training dynamics of a two-layer neural network with hidden width $3$. The model reaches only around $75\%$ validation accuracy, and a phase transition is observed near the $100$th epoch. (c) Visualization of individual neuron weights from the model trained in (b). The weights show three distinct patterns, each corresponding to a feature direction of the XOR data distribution. See the appendix on experimental details for the setup.
  • Figure 5: (a) 3-D visualization of the distribution $P$ under the embedding from the weak model; the clusters are well separated under this embedding. (b) Norm ratio $r_W$ for different values of $p$ and $\varepsilon$ with fixed sample size $n$, indicating that $r_W$ does not depend on $p$. (c) Norm ratio $r_W$ for different values of $n$ and $\varepsilon$ with fixed feature dimension $p$. For each $\varepsilon$, the slope on log-log axes is around $-1/2$, indicating that $r_W$ is proportional to $1/\sqrt{n}$. See the appendix on experimental details for the setup.
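The slope-to-scaling argument in Figure 5(c) rests on a one-line identity. If $r_W = c/\sqrt{n}$ for some constant $c$, then $\log r_W = \log c - \tfrac{1}{2}\log n$. A plot of $\log r_W$ against $\log n$ is therefore a line with slope $-1/2$. The sketch below checks this numerically with a least-squares fit on synthetic ratios. The constant $c$ and the sample sizes are arbitrary placeholders, not values from the paper's experiments.

```python
import math

def loglog_slope(xs, ys):
    """Least-squares slope of log(y) regressed on log(x)."""
    lx = [math.log(x) for x in xs]
    ly = [math.log(y) for y in ys]
    mean_x = sum(lx) / len(lx)
    mean_y = sum(ly) / len(ly)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(lx, ly))
    var = sum((a - mean_x) ** 2 for a in lx)
    return cov / var

# Synthetic norm ratios that follow r = c / sqrt(n) exactly (c is arbitrary).
c = 3.0
ns = [100, 200, 400, 800, 1600]
ratios = [c / math.sqrt(n) for n in ns]

slope = loglog_slope(ns, ratios)  # approximately -0.5
```

On real measurements the fitted slope is only approximately $-1/2$, which matches the "around $-1/2$" wording in the caption.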

Theorems & Definitions (8)

  • Lemma 3.0
  • Theorem 3.1
  • Lemma A.0
  • Theorem A.1
  • Lemma A.1
  • Lemma A.2
  • Lemma A.3
  • Lemma A.4