To grok or not to grok: Disentangling generalization and memorization on corrupted algorithmic datasets

Darshil Doshi, Aritra Das, Tianyu He, Andrey Gromov

TL;DR

The paper studies how to tell memorization apart from genuine generalization in neural networks whose training data carry corrupted labels. Using analytically tractable models on modular arithmetic tasks, it shows that a network can memorize the corrupted examples and still generalize, and that the memorizing neurons can be identified and pruned, which lowers accuracy on the corrupted examples and improves it on uncorrupted data. Regularization shifts the learned representations toward generalizing features: weight decay and dropout push all neurons to learn generalizing representations, while BatchNorm de-amplifies the output of memorizing neurons and amplifies the output of generalizing ones. With regularization, training proceeds in two stages: grokking to high train and test accuracy, followed by unlearning of the memorized examples. The results give mechanistic insight into robust generalization and suggest practical pruning strategies for both MLP and Transformer architectures.

Abstract

Robust generalization is a major challenge in deep learning, particularly when the number of trainable parameters is very large. In general, it is very difficult to know whether a network has memorized a particular set of examples, understood the underlying rule, or both. Motivated by this challenge, we study an interpretable model where generalizing representations are understood analytically and are easily distinguishable from the memorizing ones. Namely, we consider multi-layer perceptron (MLP) and Transformer architectures trained on modular arithmetic tasks, where $\xi \cdot 100\%$ of the labels are corrupted (i.e., some results of the modular operations in the training set are incorrect). We show that (i) it is possible for the network to memorize the corrupted labels and achieve $100\%$ generalization at the same time; (ii) the memorizing neurons can be identified and pruned, lowering the accuracy on corrupted data and improving the accuracy on uncorrupted data; (iii) regularization methods such as weight decay, dropout and BatchNorm force the network to ignore the corrupted data during optimization and achieve $100\%$ accuracy on the uncorrupted dataset; and (iv) the effect of these regularization methods is mechanistically interpretable: weight decay and dropout force all the neurons to learn generalizing representations, while BatchNorm de-amplifies the output of memorizing neurons and amplifies the output of the generalizing ones. Finally, we show that in the presence of regularization, the training dynamics involve two consecutive stages: first, the network undergoes grokking dynamics, reaching high train and test accuracy; second, it unlearns the memorizing representations, and the train accuracy suddenly drops from $100\%$ to $100(1-\xi)\%$.
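
The corrupted-label setup described in the abstract can be illustrated with a minimal sketch. It assumes modular addition over $\mathbb{Z}_{97}$, a data fraction $\alpha$ used for training, and corruption that replaces a label with a uniformly random incorrect one; the function name, the sampling scheme, and the seed are illustrative assumptions, not the authors' code.

```python
import random


def make_corrupted_modular_addition(p=97, alpha=0.5, xi=0.35, seed=0):
    """Split all pairs (a, b) in Z_p x Z_p into train/test and corrupt
    a fraction xi of the training labels for the task (a + b) mod p."""
    rng = random.Random(seed)
    pairs = [(a, b) for a in range(p) for b in range(p)]
    rng.shuffle(pairs)

    n_train = int(alpha * len(pairs))
    train_pairs, test_pairs = pairs[:n_train], pairs[n_train:]

    n_corrupt = int(xi * n_train)
    corrupted_idx = set(rng.sample(range(n_train), n_corrupt))

    train = []
    for i, (a, b) in enumerate(train_pairs):
        true_label = (a + b) % p
        is_corrupted = i in corrupted_idx
        if is_corrupted:
            # Assumption: a nonzero offset guarantees the label is wrong.
            label = (true_label + rng.randrange(1, p)) % p
        else:
            label = true_label
        train.append((a, b, label, is_corrupted))

    # Test labels are never corrupted.
    test = [(a, b, (a + b) % p) for a, b in test_pairs]
    return train, test


train, test = make_corrupted_modular_addition()
clean_fraction = sum(1 for a, b, y, bad in train if not bad) / len(train)
print(len(train), len(test), round(clean_fraction, 3))
```

A network that predicts the true label on every training example agrees with only the clean fraction of the stored labels, roughly $1-\xi$, which is the train accuracy the abstract reports after the memorized examples are unlearned.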

Paper Structure (49 sections, 9 equations, 32 figures, 2 tables)

Figures (32)

  • Figure 1: Training curves in various phases. All plots are made for networks trained with data-fraction $\alpha=0.5$ and corruption-fraction $\xi=0.35$. (a) No regularization: Coexistence of generalization and memorization -- both train and test accuracies are high. (b), (c) Adding weight decay: The network generalizes on the test data but does not memorize the corrupted training data, resulting in a negative generalization gap! Remarkably, the network predicts the "true" labels for the corrupted examples. We term these phases Partial Inversion and Full Inversion, based on the degree of memorization of corrupted data. ("Inversion" refers to test accuracy being higher than train accuracy.)
  • Figure 2: Grokking the modular arithmetic task over $\mathbb{Z}_{97}$ with a 2-layer MLP trained with AdamW. (a) Sharp transition in test accuracy, long after overfitting. $\overline{\text{IPR}} \coloneqq \mathbb{E}_k \left[ \text{IPR}_k \right]$ increases monotonically over time, indicating periodic representations. (b) Example row vector $U_{k\cdot}$ before and after grokking. Generalization is achieved through periodic weights (equation \ref{eq:periodic_weights}). (c) Histogram of per-neuron IPRs before and after grokking -- the distribution shifts to high IPR.
  • Figure 3: Modular Addition phase diagrams with various regularization methods. A larger data fraction provides more "correct" examples, which increases robustness to corruption. Increasing regularization, in the form of weight decay or dropout, enhances robustness to label corruption and facilitates better generalization.
  • Figure 4: Distribution of per-neuron IPR for trained networks in various phases. The Coexistence phase has a bimodal IPR distribution, where the high- and low-IPR neurons facilitate generalization and memorization, respectively. Regularization with weight decay or dropout shifts the IPR distribution towards higher values: generalizing neurons become more numerous than memorizing ones, resulting in more robust generalization and Inversion. All plots are made for networks trained with data-fraction $\alpha=0.5$ and corruption-fraction $\xi=0.35$, with various regularization strengths.
  • Figure 5: IPR distributions of networks with/without BatchNorm; the correlation between BatchNorm weights and IPR. The IPR distributions in the two models do not exhibit significant differences. However, BatchNorm helps generalization by assigning higher weights ($\gamma_k$) to high IPR neurons compared to low IPR neurons. Both models are trained at data fraction $\alpha=0.65$ and corruption fraction $\xi=0.2$, with batch-size = 64.
  • ...and 27 more figures
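
The captions above use the per-neuron IPR to separate periodic (generalizing) rows from non-periodic (memorizing) ones. The exact definition is not reproduced in this listing, so the following is a minimal sketch that assumes a common form of the inverse participation ratio: the sum of fourth powers of the Fourier magnitudes of a weight row, divided by the squared sum of their squares. The threshold of 0.2 and the hard zeroing of low-IPR rows are hypothetical choices for illustration, not the paper's pruning procedure.

```python
import cmath
import math
import random


def ipr(weights):
    """Inverse participation ratio of the Fourier spectrum of one row.

    Assumed definition: sum_k |c_k|^4 / (sum_k |c_k|^2)^2, where c_k is
    the discrete Fourier transform of the row. Values near 1/n mean the
    spectrum is spread out; larger values mean a few frequencies dominate.
    """
    n = len(weights)
    mags = []
    for k in range(n):
        coeff = sum(
            w * cmath.exp(-2j * math.pi * k * t / n)
            for t, w in enumerate(weights)
        )
        mags.append(abs(coeff))
    norm_sq = sum(m * m for m in mags)
    return sum(m ** 4 for m in mags) / (norm_sq ** 2)


p = 97
rng = random.Random(0)

# A periodic row (single frequency) and an unstructured random row.
periodic_row = [math.cos(2 * math.pi * 5 * i / p + 0.3) for i in range(p)]
random_row = [rng.gauss(0.0, 1.0) for _ in range(p)]
rows = [periodic_row, random_row]

iprs = [ipr(row) for row in rows]

# Hypothetical cutoff: keep rows whose IPR is at least the threshold.
threshold = 0.2
keep_mask = [value >= threshold for value in iprs]
pruned = [row if keep else [0.0] * p for row, keep in zip(rows, keep_mask)]

print([round(v, 3) for v in iprs], keep_mask)
```

Under this assumed definition a real single-frequency cosine has two equal Fourier peaks, so its IPR is 0.5, while an unstructured row of length 97 stays far below the cutoff. That gap is what makes a bimodal IPR histogram usable as a pruning criterion.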