
Grokking in Linear Estimators -- A Solvable Model that Groks without Understanding

Noam Levi, Alon Beck, Yohai Bar-Sinai

TL;DR

The paper analyzes grokking in a minimal linear estimator by solving the exact gradient-flow dynamics of a linear teacher–student model with Gaussian inputs. It shows that delayed generalization can arise purely from covariance-driven dynamics rather than from any qualitative shift to 'understanding'. The grokking time is determined primarily by $\lambda = \frac{d_{\mathrm{in}}}{N_{\mathrm{tr}}}$, the ratio of input dimension to number of training samples, and is modulated by initialization, output dimension, and weight decay. The authors extend the analysis semi-analytically to 2-layer linear networks and provide evidence that some predictions persist under certain nonlinear activations in an NTK-like regime. Overall, the work offers a rigorous, interpretable framework linking dataset statistics to learning dynamics and clarifies how accuracy thresholds can mislead interpretations of grokking in neural networks.
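The covariance-driven mechanism can be illustrated with a minimal NumPy sketch. It is not the authors' code; it assumes $d_\mathrm{out}=1$, inputs $x \sim \mathcal{N}(0, I)$, a teacher drawn with variance $1/d_\mathrm{in}$, a zero-initialized student, and plain full-batch gradient descent on the mean squared error. Writing $D = S - T$ for the student-teacher mismatch and $C$ for the empirical training covariance, the training loss is $D^\top C D$ and, for isotropic Gaussian inputs, the generalization loss is $\|D\|^2$. The function name `linear_teacher_student` and its parameters are hypothetical.

```python
import numpy as np

def linear_teacher_student(d_in, n_train, eta=0.01, steps=3000, seed=0):
    """Full-batch GD for a linear student S fitting a linear teacher T (d_out = 1).

    Assumptions of this sketch (not the paper's exact conventions):
    x ~ N(0, I), T ~ N(0, I / d_in), S initialized at zero,
    training loss = mean over samples of (x_n . (S - T))^2.
    Returns the training and generalization loss recorded at each step.
    """
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_train, d_in))
    T = rng.standard_normal(d_in) / np.sqrt(d_in)
    S = np.zeros(d_in)
    cov = X.T @ X / n_train              # empirical training covariance C
    train_loss, gen_loss = [], []
    for _ in range(steps):
        D = S - T                        # student-teacher mismatch
        train_loss.append(D @ cov @ D)   # D^T C D
        gen_loss.append(D @ D)           # E[(x . D)^2] = |D|^2 for x ~ N(0, I)
        S = S - eta * 2.0 * cov @ D      # gradient of D^T C D w.r.t. S is 2 C D
    return np.array(train_loss), np.array(gen_loss)

# lambda = d_in / n_train below and above 1
tr_a, gen_a = linear_teacher_student(d_in=100, n_train=200)  # lambda = 0.5
tr_b, gen_b = linear_teacher_student(d_in=150, n_train=100)  # lambda = 1.5
print(f"lambda=0.5: train {tr_a[-1]:.2e}, gen {gen_a[-1]:.2e}")
print(f"lambda=1.5: train {tr_b[-1]:.2e}, gen {gen_b[-1]:.2e}")
```

With $\lambda<1$ both losses decay, the training loss faster because it weights each mismatch direction by an eigenvalue of $C$. With $\lambda>1$, $C$ has $d_\mathrm{in}-N_\mathrm{tr}$ zero eigenvalues, so the part of $T$ in that null space is never learned and the generalization loss plateaus while the training loss still goes to zero.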

Abstract

Grokking is the intriguing phenomenon where a model learns to generalize long after it has fit the training data. We show both analytically and numerically that grokking can surprisingly occur in linear networks performing linear tasks in a simple teacher-student setup with Gaussian inputs. In this setting, the full training dynamics is derived in terms of the training and generalization data covariance matrices. We present exact predictions on how the grokking time depends on input and output dimensionality, train sample size, regularization, and network initialization. We demonstrate that the sharp increase in generalization accuracy may not imply a transition from "memorization" to "understanding", but can simply be an artifact of the accuracy measure. We provide empirical verification for our calculations, along with preliminary results indicating that some predictions also hold for deeper networks with non-linear activations.
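The claim that a sharp accuracy jump can be an artifact of the accuracy measure can be checked with a few lines of Python. This sketch assumes, as a simplification, that a prediction counts as correct when its squared error is below a threshold $\epsilon$, and that the test error is Gaussian with variance equal to the generalization loss $\mathcal{L}$ (which holds exactly for a linear model with $d_\mathrm{out}=1$ and $x \sim \mathcal{N}(0, I)$). Accuracy is then $\mathrm{erf}\big(\sqrt{\epsilon / 2\mathcal{L}}\big)$. The name `threshold_accuracy` and the decay rate are hypothetical.

```python
import math

def threshold_accuracy(gen_loss, eps=1e-3):
    """Fraction of Gaussian errors e ~ N(0, gen_loss) with e**2 < eps.

    P(|e| < sqrt(eps)) = erf(sqrt(eps / (2 * gen_loss))).
    """
    if gen_loss <= 0.0:
        return 1.0
    return math.erf(math.sqrt(eps / (2.0 * gen_loss)))

# A smooth, purely exponential loss decay (hypothetical rate 0.005 per step)
for t in range(0, 2001, 250):
    loss = math.exp(-0.005 * t)
    print(f"t={t:5d}  loss={loss:.2e}  accuracy={threshold_accuracy(loss):.3f}")
```

The loss shrinks by the same factor every 250 steps, yet the accuracy stays low for most of the run and then rises steeply to nearly 1 once the loss falls to the order of $\epsilon$.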

Paper Structure

This paper contains 23 sections, 39 equations, and 6 figures.

Figures (6)

  • Figure 1: Grokking as a function of $\lambda$. Left: empirical training (dashed) and generalization (solid) losses for $\lambda=0.1, 0.9, 1.5$ (red, blue, violet), compared with the analytical solutions (black). Center: the same comparison for the accuracy functions. Right: the grokking time as a function of $\lambda$ for different values of the threshold parameter $\epsilon$; the solid curves are numerical solutions of the expressions derived in the paper's simplified-model section, shown against the analytic grokking-time solution (dashed black). In all three panels, diamonds/stars mark where the training/generalization accuracy reaches 95%. Training uses GD with $\eta=\eta_0=0.01$, $d_\mathrm{in}=10^{3}$, $d_\mathrm{out}=1$, $\epsilon=10^{-3}$.
  • Figure 2: Effects of the output dimension $d_\mathrm{out}>1$ on grokking. Left: empirical training (dashed) and generalization (solid) losses for $d_\mathrm{out}=1, 50, 700$ (blue, red, violet), compared with the analytical solutions (black), for $\lambda=0.9$. Center: the same comparison for the accuracy functions. Right: the grokking time as a function of $d_\mathrm{out}$ for different values of $\lambda$; the solid curves are numerical solutions of the expressions derived in the paper's output-dimension section. In all three panels, diamonds/stars mark where the training/generalization accuracy reaches 95%, shown for $d_\mathrm{out}^\mathrm{max}\simeq 50$, where the grokking time is maximal. Training uses GD with $\eta=\eta_0=0.01$, $d_\mathrm{in}=10^{3}$, $\epsilon=10^{-3}$.
  • Figure 3: Effects of weight decay ($\gamma$) on grokking. Left: empirical training (dashed) and generalization (solid) losses for $\gamma=10^{-5}, 10^{-3}, 10^{-2}$ (blue, red, violet), compared with the analytical solutions (black), for $\lambda=0.9$. Center: the same comparison for the accuracy functions. Right: the grokking time as a function of $\gamma$ for different values of $\lambda$; the solid curves are numerical solutions of the expressions derived in the main text, and the shaded gray region marks where training and generalization saturate without reaching perfect generalization. In all three panels, diamonds/stars mark where the accuracy reaches 95%. Training uses GD with $\eta=\eta_0=0.01$, $d_\mathrm{in}=10^{3}$, $d_\mathrm{out}=1$, $\epsilon=10^{-3}$.
  • Figure 4: Grokking time phase diagrams. Left: a contour plot of the grokking time difference as a function of $\gamma$ and $d_\mathrm{out}$. Red shades indicate shorter grokking times and blue shades longer ones; white regions indicate no grokking, since the generalization accuracy never reaches $95\%$. Center and Right: analogous phase diagrams for the grokking time difference as a function of $(\gamma, \lambda)$ and $(d_\mathrm{out}, \lambda)$, respectively. All three plots are obtained by numerically solving $\mathcal{A}(t^*)=0.95$ for the grokking time, using the analytic formulas quoted in the main text. The fixed parameters are $\eta_0=0.01$, $\epsilon=10^{-3}$.
  • Figure 5: 2-layer network and nonlinearities. Top row: empirical training (dashed) and generalization (solid) losses (left) and accuracies (right) for a two-layer MLP (1000-$d_h$-5) with linear activations and $d_h=50, 200$ (blue, red), compared with the analytical solutions (black). Bottom row: the same comparison for a two-layer MLP (1000-$d_h$-5) with $\tanh$ activations in the hidden layer. In both cases, training uses full-batch gradient descent with $\eta=\eta_0=0.01$, $d_\mathrm{in}=1000$, $d_\mathrm{out}=5$, $\epsilon=10^{-4}$.
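The figure captions define the grokking time as the first time $t^*$ at which an accuracy curve reaches 95%, $\mathcal{A}(t^*)=0.95$, and Figure 4 plots a grokking time difference, which this sketch takes to be generalization minus training. The sketch measures both times empirically for a linear student under assumptions of its own rather than the paper's exact conventions: $d_\mathrm{out}=1$, $x \sim \mathcal{N}(0, I)$, zero initialization, weight decay implemented as a penalty $\gamma\|S\|^2$, and a sample counted as correct when its squared error is below $\epsilon$. The names `first_crossing` and `grokking_times` are hypothetical.

```python
import numpy as np

def first_crossing(curve, level=0.95):
    """First step at which an accuracy curve reaches `level`, or None."""
    hits = np.flatnonzero(np.asarray(curve) >= level)
    return int(hits[0]) if hits.size else None

def grokking_times(d_in=100, n_train=200, n_test=1000, gamma=0.0,
                   eta=0.01, steps=4000, eps=1e-3, seed=0):
    """Steps at which train and test accuracy first reach 95%.

    Hypothetical setup: d_out = 1, x ~ N(0, I), zero-initialized student,
    loss D^T C D + gamma * |S|^2 with C the empirical train covariance,
    and a sample counted as correct when its squared error is below eps.
    """
    rng = np.random.default_rng(seed)
    X_tr = rng.standard_normal((n_train, d_in))
    X_te = rng.standard_normal((n_test, d_in))
    T = rng.standard_normal(d_in) / np.sqrt(d_in)
    S = np.zeros(d_in)
    cov = X_tr.T @ X_tr / n_train
    acc_tr, acc_te = [], []
    for _ in range(steps):
        D = S - T
        acc_tr.append(np.mean((X_tr @ D) ** 2 < eps))  # train accuracy
        acc_te.append(np.mean((X_te @ D) ** 2 < eps))  # test accuracy
        # gradient of D^T C D + gamma |S|^2 with respect to S
        S = S - eta * (2.0 * cov @ D + 2.0 * gamma * S)
    return first_crossing(acc_tr), first_crossing(acc_te)

t_tr, t_te = grokking_times()
print("no weight decay:", t_tr, t_te, "gap:", t_te - t_tr)
print("gamma = 0.05:", grokking_times(gamma=0.05))
```

With $\lambda=0.5$ and no weight decay, training accuracy reaches 95% first and test accuracy follows later. With a comparatively large $\gamma=0.05$, the regularized minimizer $S^* = (C+\gamma I)^{-1} C T$ stays far enough from the teacher that test accuracy never reaches 95% and the function returns `None`, loosely mirroring the white no-grokking regions of Figure 4.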