Grokking modular arithmetic

Andrey Gromov

TL;DR

This work investigates grokking (the abrupt rise in generalization after substantial training) in a minimal two-layer MLP learning modular arithmetic. It provides analytic weight constructions based on Fourier-style feature maps that implement modular addition, and shows that vanilla gradient descent and AdamW discover these features, achieving 100% test accuracy under suitable data regimes. The study connects grokking to explicit feature learning, demonstrates interpretability of learned representations, and extends insights to broader modular functions while examining scaling with width and data. It also discusses limitations and future directions, including complexity measures, deeper architectures, and practical applications in cryptography and algorithmic tasks.

Abstract

We present a simple neural network that can learn modular arithmetic tasks and exhibits a sudden jump in generalization known as "grokking". Concretely, we present (i) fully-connected two-layer networks that exhibit grokking on various modular arithmetic tasks under vanilla gradient descent with the MSE loss function in the absence of any regularization; (ii) evidence that grokking modular arithmetic corresponds to learning specific feature maps whose structure is determined by the task; (iii) analytic expressions for the weights (and thus for the feature maps) that solve a large class of modular arithmetic tasks; and (iv) evidence that these feature maps are also found by vanilla gradient descent as well as AdamW, thereby establishing complete interpretability of the representations learnt by the network.
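
Point (iii) can be illustrated with a small, dependency-free sketch. The exact analytic expressions are not reproduced in this excerpt, so everything below is an assumption made for illustration: cosine weights with one random frequency per hidden neuron, a quadratic activation, and phases constrained so that $\varphi^{(1)}_k + \varphi^{(2)}_k - \varphi^{(3)}_k = 0$ (the combination that the Figure 3 caption reports as peaked around $0$). The names `fourier_weights`, `predict`, and `accuracy` are hypothetical helpers, not the authors' code.

```python
import math
import random


def fourier_weights(p, width, rng):
    """Cosine weights with random frequencies and phases (assumed form).

    Neuron k gets one frequency f_k in 1..p-1 and phases phi1, phi2, phi3
    chosen so that phi1 + phi2 - phi3 = 0.
    """
    U, V, W = [], [], []
    for _ in range(width):
        f = rng.randrange(1, p)
        phi1 = rng.uniform(0.0, 2.0 * math.pi)
        phi2 = rng.uniform(0.0, 2.0 * math.pi)
        phi3 = phi1 + phi2  # phase constraint (assumption, see lead-in)
        w = 2.0 * math.pi * f / p
        U.append([math.cos(w * n + phi1) for n in range(p)])  # first input
        V.append([math.cos(w * m + phi2) for m in range(p)])  # second input
        W.append([math.cos(w * q + phi3) for q in range(p)])  # output layer
    return U, V, W


def predict(U, V, W, n, m, p):
    """Two-layer forward pass with a quadratic activation; returns the argmax."""
    hidden = [(u[n] + v[m]) ** 2 for u, v in zip(U, V)]
    logits = [sum(h * w[q] for h, w in zip(hidden, W)) for q in range(p)]
    return max(range(p), key=logits.__getitem__)


def accuracy(p=11, width=1500, seed=0):
    """Fraction of all p*p input pairs mapped to (n + m) mod p."""
    rng = random.Random(seed)
    U, V, W = fourier_weights(p, width, rng)
    correct = sum(
        predict(U, V, W, n, m, p) == (n + m) % p
        for n in range(p)
        for m in range(p)
    )
    return correct / (p * p)


if __name__ == "__main__":
    print(accuracy())
```

The reason this works: expanding the square produces a cross term proportional to $\cos(2\pi f_k (n+m)/p + \varphi^{(1)}_k + \varphi^{(2)}_k)$, which the output layer matches coherently only at $q = (n+m) \bmod p$, contributing about half the width to that logit. The remaining terms carry random phases and add up only like the square root of the width, so accuracy improves as the network gets wider.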

Paper Structure (15 sections, 24 equations, 10 figures)


Figures (10)

  • Figure 1: Dynamics under GD for the minimal model \ref{eq_f} with MSE loss and $\alpha = 0.49$. (a) Train and test loss. Train loss generally decays monotonically, while test loss reaches its maximum right before the onset of grokking. (b) Norms of weight matrices during training. We do not observe a large increase in weight norms as in \cite{thilak2022slingshot}, but we do see that weight norms start growing at the onset of grokking. (c) Train and test accuracy showing the delayed and sudden onset of generalization. (d) Norms of gradient vectors. The dynamics accelerate until the test loss maximum is reached and then slowly decelerate.
  • Figure 2: Preactivations. First row: Preactivation $h^{(2)}_6(n,m)$. Second row: Fourier image of the preactivation $h^{(2)}_6(n,m)$. Third row: Preactivation $h^{(1)}_6(n,m)$ or $h^{(1)}_{30}(n,m)$. First column: At initialization. Second column: Found by vanilla GD. The Fourier image shows a single series of peaks corresponding to $m+n = 6 \,\,\textrm{mod}\,\, 97$. Third column: Evaluated using the analytic solution \ref{eq_sol1}--\ref{eq_sol2}. The Fourier image shows the same peak as found by GD, but also weak peaks corresponding to $2m = 6 \,\,\textrm{mod}\,\, 97$, $2n = 6 \,\,\textrm{mod}\,\, 97$ and $m - n = 6 \,\,\textrm{mod}\,\, 97$ that were suppressed by the choice of phases via \ref{eq_phases}.
  • Figure 3: Solutions found by the optimizer. In all cases the distribution of $\varphi^{(1)}_k + \varphi^{(2)}_k - \varphi^{(3)}_k$ is strongly peaked around $0$. The solutions found by AdamW are closer to the analytic ones because the phases are more strongly peaked around $0$. Note that for solutions found by the optimizer the phases are not i.i.d., which leads to better accuracy.
  • Figure 4: Scaling with width and data. (a) Grokking time vs. the amount of training data for various optimizers. The abrupt change in grokking time is observed at different $\alpha$. Momentum appears to play a major role in reducing both the grokking time and $\alpha$. (b) Test accuracy as a function of width for the solutions found by GD and AdamW and for the analytic solution \ref{eq_sol1}--\ref{eq_sol2}. The optimizer can tune the phases better than a random uniform distribution in order to ensure better cancellations. The shape of the curves also depends on the amount of data used for training and the number of epochs. Here we took $\alpha = 0.5$ and trained longer for GD.
  • Figure 5: Inverse participation ratio. IPR plotted against the dynamics (under AdamW) of train and test accuracy. Empirically, we see $4$ regimes: (i) early training, when IPR grows linearly and slowly; (ii) transition from slow linear growth to fast linear growth, a period that coincides with grokking; (iii) fast linear growth, which starts after $100\%$ test accuracy is reached; (iv) saturation, once the weights become periodic. The dashed line indicates $\overline{\textrm{IPR}}_2$ for the exact solution \ref{eq_sol1}--\ref{eq_sol2}. The gap between the two indicates that even in the final solution there is quite a bit of noise, leading to some mild delocalization in Fourier space. More training and more data help to reduce the gap.
  • ...and 5 more figures
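
The Figure 5 caption tracks an inverse participation ratio (IPR) of the weights in Fourier space, but this excerpt does not give the formula. The sketch below therefore assumes the standard definition, $\textrm{IPR}_r = \sum_k |w_k|^{2r} / (\sum_k |w_k|^2)^r$, applied to the discrete Fourier coefficients $w_k$ of one weight row. The helper names `dft` and `ipr` are hypothetical, and the paper's own normalization and averaging over rows may differ.

```python
import cmath
import math
import random


def dft(x):
    """Naive O(p^2) discrete Fourier transform of a real sequence."""
    p = len(x)
    return [
        sum(x[n] * cmath.exp(-2j * math.pi * k * n / p) for n in range(p))
        for k in range(p)
    ]


def ipr(row, r=2):
    """Assumed definition: sum_k |w_k|^(2r) / (sum_k |w_k|^2)^r."""
    power = [abs(c) ** 2 for c in dft(row)]
    return sum(v ** r for v in power) / sum(power) ** r


p = 97
# A perfectly periodic row: power sits in two Fourier peaks (frequencies +6 and -6).
periodic = [math.cos(2 * math.pi * 6 * n / p + 0.3) for n in range(p)]
# A random row, standing in for weights at initialization.
rng = random.Random(0)
noise = [rng.gauss(0.0, 1.0) for _ in range(p)]

print(ipr(periodic))  # close to 0.5, since the power is split over two peaks
print(ipr(noise))     # much smaller, since the power is spread over many frequencies
```

Under this definition a localized (periodic) row scores high and a delocalized (noisy) row scores low, which matches the caption's reading of rising IPR as the weights becoming periodic.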