
Omnigrok: Grokking Beyond Algorithmic Data

Ziming Liu, Eric J. Michaud, Max Tegmark

TL;DR

This work analyzes grokking (the puzzling delay in generalization long after overfitting) through the lens of neural loss landscapes, introducing the LU mechanism: as a function of weight norm, training loss follows an L-shape while test loss follows a U-shape. By reducing the landscape to a one-dimensional function of the weight norm (optimizing over the weight direction at each fixed norm), the authors show how initialization scale and weight decay govern the time to generalize, including a generalization time that scales as t ∝ 1/γ, where γ is the weight decay. They demonstrate grokking across diverse tasks (algorithmic data, MNIST, IMDb, QM9) and find that how strongly a task groks correlates with how much it relies on learning good representations. The study then argues that this reliance on representation learning explains why grokking is dramatic for algorithmic datasets but muted for MNIST, and shows that constraining the weight norm can almost eliminate grokking. Overall, the loss-landscape perspective provides a coherent, predictive framework for grokking across domains and highlights representation learning as a central factor.
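
The t ∝ 1/γ scaling can be motivated with a back-of-the-envelope argument. This is a sketch under our own simplifying assumptions, not the paper's derivation: plain gradient flow on $L_{\rm train}(w) + \frac{\gamma}{2}w^2$, a training loss that is exactly flat in the overfitting region $w > w_c$, and an initial norm $w_0 > w_c$ ($w_0$ is our notation).

```latex
% Assumption: dL_train/dw = 0 for w > w_c, so only the weight-decay term drives w.
\frac{dw}{dt} = -\frac{d}{dw}\Big[L_{\mathrm{train}}(w) + \tfrac{\gamma}{2}w^{2}\Big] = -\gamma w
\quad (w > w_c)
\;\Longrightarrow\; w(t) = w_0\, e^{-\gamma t}
\;\Longrightarrow\; t_{\mathrm{gen}} = \frac{1}{\gamma}\ln\frac{w_0}{w_c} \;\propto\; \frac{1}{\gamma}.
```

With discrete updates $w \leftarrow (1-\eta\gamma)\,w$ at learning rate $\eta$, the same argument gives roughly $\ln(w_0/w_c)/(\eta\gamma)$ steps.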

Abstract

Grokking, the unusual phenomenon for algorithmic datasets where generalization happens long after overfitting the training data, has remained elusive. We aim to understand grokking by analyzing the loss landscapes of neural networks, identifying the mismatch between training and test losses as the cause for grokking. We refer to this as the "LU mechanism" because training and test losses (against model weight norm) typically resemble "L" and "U", respectively. This simple mechanism can nicely explain many aspects of grokking: data size dependence, weight decay dependence, the emergence of representations, etc. Guided by the intuitive picture, we are able to induce grokking on tasks involving images, language and molecules. In the reverse direction, we are able to eliminate grokking for algorithmic datasets. We attribute the dramatic nature of grokking for algorithmic datasets to representation learning.
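
The TL;DR attributes the elimination of grokking to constraining the weight norm. One simple way to impose such a constraint is projected gradient descent: take an ordinary gradient step, then rescale the weights back onto the sphere of fixed norm. The sketch below assumes this rescaling scheme, a hypothetical quadratic loss, and a hypothetical unit-norm teacher vector; it illustrates the mechanics only and is not necessarily the authors' exact procedure.

```python
"""Projected gradient descent with a fixed weight norm (hypothetical sketch)."""
import math


def norm(v):
    # L2 norm of a list of floats.
    return math.sqrt(sum(x * x for x in v))


def project_to_norm(v, target):
    # Rescale v so that its L2 norm equals `target` (projection onto the sphere).
    n = norm(v)
    if n == 0.0:
        raise ValueError("cannot rescale the zero vector")
    return [x * target / n for x in v]


def constrained_gd(grad_fn, v0, target, lr=0.1, steps=200):
    # Ordinary gradient step followed by projection back onto ||v|| = target,
    # so the weight norm can never drift into the w > w_c overfitting region.
    v = project_to_norm(v0, target)
    for _ in range(steps):
        g = grad_fn(v)
        v = project_to_norm([x - lr * gx for x, gx in zip(v, g)], target)
    return v


if __name__ == "__main__":
    # Hypothetical loss ||v - teacher||^2 with a unit-norm teacher vector.
    teacher = [0.6, 0.8, 0.0]
    grad = lambda v: [2.0 * (x - t) for x, t in zip(v, teacher)]
    v = constrained_gd(grad, [0.0, 0.0, 5.0], target=1.0)
    print(norm(v), v)
```

Because the norm is fixed by construction, only the direction is learned; in the LU picture this keeps training on the sphere $w = w_c$ instead of letting it wander into the flat overfitting region.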

Paper Structure

This paper contains 14 sections, 9 equations, 12 figures, and 1 table.

Figures (12)

  • Figure 1: (a) $w$: $L_2$ norm of model weights. Generalizing solutions (green stars) are concentrated around a sphere in the weight space where $w\approx w_c$ (green). Overfitting solutions (orange) populate the $w\gtrsim w_c$ region. (b) The training loss (orange) and test loss (gray) have the shape of L and U, respectively. Their mismatch in the $w>w_c$ region leads to fast-slow dynamics, resulting in grokking.
  • Figure 2: Teacher-student setup. $\alpha$: student initialization scale, $\gamma$: weight decay. (a) The reduced training loss and test loss have the shape of "L" and "U", respectively. (b) Top row: large initialization ($\alpha=2.0$) can show no generalization (no reg), grokking (small reg), or fast generalization (large reg). Bottom row: small initialization ($\alpha=0.5$) always generalizes fast, regardless of weight decay. (c) $\alpha=2$. The number of steps to overfit is independent of weight decay, while the number of steps to generalize scales inversely with weight decay.
  • Figure 3: MNIST. (a) reduced training error; (b) reduced test error. Comparing A and B: a larger weight norm makes learning grok (delays generalization). Comparing B and C: a larger training set makes learning de-grok (speeds up generalization). (c) The LU shape is more pronounced for smaller training sets. (d) Accuracy curves for MNIST in the setting where grokking is observed. (e) Time to generalize as a function of training set size $N$.
  • Figure 4: We use an LSTM to classify the sentiment of IMDb reviews. (a) training error; (b) test error; (c) reduced losses for data size 1k (top) and 50k (bottom); (d) with 1k data, a (weak) grokking signal is observed for large initializations ($\alpha=6$), while no grokking is observed for standard initializations ($\alpha=1$).
  • Figure 5: We use a GCNN to predict isotropic polarizability of molecules in the QM9 dataset. (a) training loss; (b) test loss; (c) reduced losses for data size 100 (top) and 3000 (bottom); (d) with 200 training samples, grokking is observed for large initialization ($\alpha=3$), while no grokking is observed for standard initializations ($\alpha=1$).
  • ...and 7 more figures