Grokking modular arithmetic
Andrey Gromov
TL;DR
This work investigates grokking, the abrupt rise in generalization after substantial training, in a minimal two-layer MLP learning modular arithmetic. It provides analytic weight constructions, based on Fourier-style feature maps, that implement modular addition, and shows that vanilla gradient descent and AdamW discover these same features and reach 100% test accuracy under suitable data regimes. The study connects grokking to explicit feature learning, demonstrates that the learned representations are interpretable, and extends the analysis to a broader class of modular functions while examining scaling with network width and training data. It also discusses limitations and future directions, including complexity measures, deeper architectures, and practical applications in cryptography and algorithmic tasks.
Abstract
We present a simple neural network that can learn modular arithmetic tasks and exhibits a sudden jump in generalization known as "grokking". Concretely, we present (i) fully connected two-layer networks that exhibit grokking on various modular arithmetic tasks under vanilla gradient descent with the MSE loss function in the absence of any regularization; (ii) evidence that grokking modular arithmetic corresponds to learning specific feature maps whose structure is determined by the task; (iii) analytic expressions for the weights (and thus for the feature maps) that solve a large class of modular arithmetic tasks; and (iv) evidence that these feature maps are also found by vanilla gradient descent as well as AdamW, thereby establishing complete interpretability of the representations learnt by the network.
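The idea behind an analytic, Fourier-style solution to modular addition (item (iii)) can be illustrated with a minimal sketch. It rests on assumptions that may differ from the paper's exact conventions. The inputs n and m are one-hot encoded and concatenated, and the activation is quadratic, phi(x) = x^2. Each hidden unit k has a random nonzero frequency s_k and random phases phi1_k, phi2_k. The helper names `fourier_weights`, `forward` and `accuracy` are hypothetical. Write a = 2*pi*s_k*n/p + phi1_k and b = 2*pi*s_k*m/p + phi2_k. Expanding (cos a + cos b)^2 yields a cos(a + b) term. Its product with the output weight cos(2*pi*s_k*q/p + phi1_k + phi2_k) contains cos(2*pi*s_k*(n + m - q)/p), which is independent of the random phases. That term therefore adds up coherently across hidden units only when q = (n + m) mod p. All remaining terms carry random phases and largely cancel.

```python
import math
import random


def fourier_weights(p, width, seed=0):
    """Hypothetical Fourier-style weights for the task (n + m) mod p.

    Hidden unit k reads input n via cos(2*pi*s_k*n/p + phi1_k), input m via
    cos(2*pi*s_k*m/p + phi2_k), and writes to output q via
    cos(2*pi*s_k*q/p + phi1_k + phi2_k).
    """
    rng = random.Random(seed)
    w_in_n, w_in_m, w_out = [], [], []
    for _ in range(width):
        s = rng.randrange(1, p)               # random nonzero frequency
        phi1 = rng.uniform(0.0, 2 * math.pi)  # random phase for input n
        phi2 = rng.uniform(0.0, 2 * math.pi)  # random phase for input m
        w_in_n.append([math.cos(2 * math.pi * s * n / p + phi1) for n in range(p)])
        w_in_m.append([math.cos(2 * math.pi * s * m / p + phi2) for m in range(p)])
        w_out.append([math.cos(2 * math.pi * s * q / p + phi1 + phi2) for q in range(p)])
    return w_in_n, w_in_m, w_out


def forward(weights, n, m):
    """Two-layer MLP on concatenated one-hot inputs with an x**2 activation."""
    w_in_n, w_in_m, w_out = weights
    p = len(w_in_n[0])
    logits = [0.0] * p
    for k in range(len(w_in_n)):
        # A one-hot input just selects one column of the first-layer weights.
        h = (w_in_n[k][n] + w_in_m[k][m]) ** 2
        row = w_out[k]
        for q in range(p):
            logits[q] += row[q] * h
    return logits


def accuracy(p, width, seed=0):
    """Fraction of all p*p input pairs whose argmax output equals (n + m) mod p."""
    weights = fourier_weights(p, width, seed)
    correct = 0
    for n in range(p):
        for m in range(p):
            logits = forward(weights, n, m)
            pred = max(range(p), key=logits.__getitem__)
            correct += pred == (n + m) % p
    return correct / (p * p)


print(accuracy(p=11, width=1000))
```

The coherent term grows linearly with the width, while the random-phase terms grow only like its square root. A sufficiently wide network of this form should therefore classify every pair correctly. With too few hidden units, the random-phase terms can swamp the signal.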
