
Latent Algorithmic Structure Precedes Grokking: A Mechanistic Study of ReLU MLPs on Modular Arithmetic

Anand Swaroop

Abstract

Grokking, the phenomenon in which a neural network's validation accuracy on modular addition of two integers rises long after the training data has been memorized, has been characterized in previous work as producing sinusoidal input weight distributions in transformers and multi-layer perceptrons (MLPs). We find empirically that ReLU MLPs in our experimental setting instead learn near-binary square-wave input weights, in which intermediate values appear only near sign-change boundaries, alongside output weights whose dominant Fourier phases satisfy the phase-sum relation $\phi_{\mathrm{out}} = \phi_a + \phi_b$; this relation holds even when the model is trained on noisy data and fails to grok. We extract the frequency and phase of each neuron's weights via the discrete Fourier transform (DFT) and construct an idealized MLP: input weights are replaced by perfect binary square waves and output weights by cosines, both parametrized by the frequencies, phases, and amplitudes of the dominant Fourier components of the real model's weights. This idealized model achieves 95.5% accuracy when its frequencies and phases are extracted from a model trained on noisy data that itself achieves only 0.23% accuracy. This suggests that grokking does not discover the correct algorithm but instead sharpens an algorithm already substantially encoded during memorization, progressively binarizing the input weights into cleaner square waves and aligning the output weights until generalization becomes possible.
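The extraction-and-idealization procedure described in the abstract can be sketched in Python. This is a minimal sketch under assumptions the abstract does not state: inputs $a, b \in \{0, \dots, p-1\}$ index per-neuron weight vectors `w_a` and `w_b` of length $p$, the hidden unit is $h_n = \mathrm{ReLU}(W_a[n,a] + W_b[n,b])$, the logit for class $c$ is $\sum_n W_{\mathrm{out}}[c,n]\,h_n$, weights follow the convention $A\cos(2\pi k x/p + \phi)$, and each neuron's frequency is read from `w_a` and reused for `w_b` and `w_out`. All function names are hypothetical, and amplitudes are taken directly from the fundamental DFT component without any rescaling.

```python
import cmath
import math


def fourier_coefficient(w, k):
    """DFT coefficient X[k] = sum_x w[x] * exp(-2*pi*i*k*x/p) of a length-p vector."""
    p = len(w)
    return sum(w[x] * cmath.exp(-2j * math.pi * k * x / p) for x in range(p))


def dominant_frequency(w):
    """Nonzero frequency k in 1..p//2 with the largest |X[k]| (input is real, so
    the remaining frequencies are conjugate mirrors)."""
    p = len(w)
    return max(range(1, p // 2 + 1), key=lambda k: abs(fourier_coefficient(w, k)))


def phase_and_amplitude(w, k):
    """Phase and amplitude of w at frequency k, using w[x] ~ A*cos(2*pi*k*x/p + phi).
    The factor 2/p assumes k != p/2, which always holds for odd p."""
    coeff = fourier_coefficient(w, k)
    return cmath.phase(coeff), 2 * abs(coeff) / len(w)


def extract_neuron_params(w_a, w_b, w_out):
    """Hypothetical helper: read one neuron's (k, phases, amplitudes) from its weights.
    The frequency is taken from w_a and reused for w_b and w_out."""
    k = dominant_frequency(w_a)
    phi_a, amp_a = phase_and_amplitude(w_a, k)
    phi_b, amp_b = phase_and_amplitude(w_b, k)
    phi_out, amp_out = phase_and_amplitude(w_out, k)
    return k, phi_a, phi_b, phi_out, amp_a, amp_b, amp_out


def square_wave(p, k, phase, amplitude):
    """Ideal binary input weights: +amplitude where cos(...) >= 0, else -amplitude."""
    return [amplitude if math.cos(2 * math.pi * k * x / p + phase) >= 0 else -amplitude
            for x in range(p)]


def cosine_wave(p, k, phase, amplitude):
    """Ideal output weights: amplitude * cos(2*pi*k*c/p + phase)."""
    return [amplitude * math.cos(2 * math.pi * k * c / p + phase) for c in range(p)]


def build_idealized_mlp(neuron_params, p):
    """Replace each neuron's weights by ideal waves built from its parameters."""
    return [
        (square_wave(p, k, pa, aa), square_wave(p, k, pb, ab), cosine_wave(p, k, po, ao))
        for k, pa, pb, po, aa, ab, ao in neuron_params
    ]


def idealized_logits(model, a, b, p):
    """Logits for input pair (a, b): sum_n w_out[n][c] * ReLU(w_a[n][a] + w_b[n][b])."""
    logits = [0.0] * p
    for w_a, w_b, w_out in model:
        h = max(0.0, w_a[a] + w_b[b])  # ReLU hidden activation
        if h > 0.0:
            for c in range(p):
                logits[c] += h * w_out[c]
    return logits


def accuracy(model, p):
    """Fraction of all p*p pairs whose argmax logit equals (a + b) mod p."""
    correct = 0
    for a in range(p):
        for b in range(p):
            logits = idealized_logits(model, a, b, p)
            pred = max(range(p), key=logits.__getitem__)
            correct += pred == (a + b) % p
    return correct / (p * p)
```

For example, neurons whose phases lie on a uniform grid with `phi_out = phi_a + phi_b` by construction can be passed to `build_idealized_mlp` and scored with `accuracy(model, p)`. One way to see why the phase-sum relation matters: with equal amplitudes, $\mathrm{ReLU}(\pm A \pm A)$ is nonzero only when both inputs are $+A$, so each idealized neuron computes an AND of two half-period indicators. The product of their fundamental components contains a term proportional to $\cos(\omega(a+b) + \phi_a + \phi_b)$ with $\omega = 2\pi k/p$, which an output cosine $\cos(\omega c + \phi_{\mathrm{out}})$ reads out coherently across neurons when $\phi_{\mathrm{out}} = \phi_a + \phi_b$.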

Paper Structure (8 sections, 3 equations, 5 figures, 2 tables)

Figures (5)

  • Figure 1: Actual $W_a$ for various neurons plotted in blue, ideal square wave in gray. Intermediate values appear only near sign-change boundaries. Plots are scaled horizontally so that at most 5 periods are visible, to show fine detail.
  • Figure 2: Actual $W_b$ for various neurons plotted in red, ideal square wave in gray. Intermediate values appear only near sign-change boundaries, and the dominant frequency extracted from $W_b$ matches that of $W_a$.
  • Figure 3: Accuracy and periodicity across noise levels.
  • Figure 4: $\phi_{\mathrm{out}}$ against $\phi_a + \phi_b$ across all neurons in the clean model, wrapped to $[-\pi, \pi]$. Restricted to structured neurons, the correlation is nearly perfect (circular correlation $r = 0.9993$).
  • Figure 5: $\phi_{\mathrm{out}}$ against $\phi_a + \phi_b$ for various $\alpha$ levels. The highest-periodicity neurons in each model lie close to the diagonal, regardless of the spread of the lower-periodicity neurons. Circular correlation coefficients are listed: $r$ is taken over all neurons and $r'$ over structured neurons only (periodicity $> 12$). For all $\alpha$ levels, $r' \ge 0.998$, indicating that structured neurons satisfy the phase-sum relation regardless of noise.
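The phase comparison reported in Figures 4 and 5 can be sketched as follows. The captions do not say which circular correlation coefficient is used; this sketch assumes the Jammalamadaka-SenGupta coefficient, and it takes each neuron's periodicity score as a precomputed input because its definition is not reproduced here. The names `wrap`, `circular_correlation`, and `phase_sum_r` are hypothetical.

```python
import math


def wrap(theta):
    """Wrap an angle to [-pi, pi)."""
    return (theta + math.pi) % (2 * math.pi) - math.pi


def circular_mean(angles):
    """Direction of the resultant vector of the angles."""
    return math.atan2(sum(math.sin(t) for t in angles),
                      sum(math.cos(t) for t in angles))


def circular_correlation(alpha, beta):
    """Jammalamadaka-SenGupta circular correlation between two angle lists."""
    mean_a, mean_b = circular_mean(alpha), circular_mean(beta)
    sa = [math.sin(t - mean_a) for t in alpha]
    sb = [math.sin(t - mean_b) for t in beta]
    num = sum(x * y for x, y in zip(sa, sb))
    den = math.sqrt(sum(x * x for x in sa) * sum(y * y for y in sb))
    return num / den


def phase_sum_r(phi_a, phi_b, phi_out, periodicity=None, threshold=12.0):
    """Circular correlation between phi_out and wrap(phi_a + phi_b).

    With periodicity=None every neuron is used (the captions' r); otherwise only
    neurons whose precomputed periodicity score exceeds `threshold` are kept (r').
    """
    keep = range(len(phi_out))
    if periodicity is not None:
        keep = [i for i in keep if periodicity[i] > threshold]
    summed = [wrap(phi_a[i] + phi_b[i]) for i in keep]
    out = [wrap(phi_out[i]) for i in keep]
    return circular_correlation(out, summed)
```

Wrapping does not change the sines the coefficient is built from, so it only matters for plotting on $[-\pi, \pi]$ as in the captions.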