Mechanistic Insights into Grokking from the Embedding Layer
Hilal AlQuabeh, Velibor Bojkovic, Munachiso Nwadike, Kentaro Inui
TL;DR
This work analyzes grokking, the delayed generalization that follows perfect training accuracy, and argues that trainable embedding layers induce the phenomenon in MLPs solving modular arithmetic tasks. It identifies two mechanisms: sparse embedding updates for rare tokens, and bilinear coupling between the embeddings and the first-layer weights, which creates saddle points. To counter them, it proposes frequency-aware sampling and embedding-specific learning-rate adjustments, notably an adaptive scheme in which the ratio $c = \frac{\eta_E}{\eta_W}$ is proportional to $\frac{\sigma_{\max}(\mathbf{E})}{\sigma_{\max}(\mathbf{W})} \cdot \frac{f_W}{f_E}$, with $c$ empirically set to $10$ to accelerate convergence. The study also shows that balancing update dynamics with Adam-LR improves stability and generalization, and Hessian analyses reveal more balanced curvature between the embedding and downstream weights. Although the experiments focus on MLPs, the findings bear on Transformer optimization, where bilinear interactions also arise, and offer practical strategies for mitigating grokking and improving generalization in bilinear systems.
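The learning-rate ratio above can be computed directly from the current parameters. The following is a minimal pure-Python sketch under stated assumptions, not the paper's implementation: `sigma_max` estimates the largest singular value by power iteration; `f_E` and `f_W` are read as the fraction of training steps in which the embeddings and the first-layer weights receive a gradient (our interpretation); and the proportionality constant `k` is a hypothetical knob. All function names are illustrative.

```python
import math
import random


def sigma_max(matrix, iters=200, seed=0):
    """Estimate the largest singular value of a dense matrix (list of rows) by
    power iteration on M^T M, then return ||M v|| for the converged unit vector v."""
    rows, cols = len(matrix), len(matrix[0])
    rng = random.Random(seed)
    v = [rng.random() + 0.1 for _ in range(cols)]  # positive random start vector
    for _ in range(iters):
        u = [sum(matrix[i][j] * v[j] for j in range(cols)) for i in range(rows)]  # u = M v
        v = [sum(matrix[i][j] * u[i] for i in range(rows)) for j in range(cols)]  # v = M^T u
        norm = math.sqrt(sum(x * x for x in v))
        if norm == 0.0:  # zero matrix: no nonzero singular value
            return 0.0
        v = [x / norm for x in v]
    mv = [sum(matrix[i][j] * v[j] for j in range(cols)) for i in range(rows)]
    return math.sqrt(sum(x * x for x in mv))


def embedding_lr_ratio(E, W, f_E, f_W, k=1.0):
    """c = eta_E / eta_W = k * (sigma_max(E) / sigma_max(W)) * (f_W / f_E).

    f_E, f_W: ASSUMED update frequencies (fraction of steps receiving a gradient).
    k: HYPOTHETICAL proportionality constant.
    """
    return k * (sigma_max(E) / sigma_max(W)) * (f_W / f_E)


# Toy example: diagonal matrices make the singular values easy to read off.
E = [[3.0, 0.0], [0.0, 1.0]]  # sigma_max(E) = 3
W = [[1.5, 0.0], [0.0, 0.5]]  # sigma_max(W) = 1.5
c = embedding_lr_ratio(E, W, f_E=0.25, f_W=1.0)
eta_W = 1e-3
eta_E = c * eta_W
print(c, eta_E)
```

The diagonal example gives \(c = (3/1.5) \cdot (1.0/0.25) = 8\). With real parameters one would pass the embedding matrix and the first-layer weight matrix, then assign `eta_E` and `eta_W` to separate parameter groups in an optimizer that supports per-group learning rates.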
Abstract
Grokking, in which a neural network generalizes only long after reaching perfect training performance, has been observed in Transformers and MLPs, but the components driving it remain underexplored. We show that embeddings are central to grokking: introducing them into MLPs induces delayed generalization on modular arithmetic tasks, whereas MLPs without embeddings can generalize immediately. Our analysis identifies two key mechanisms: (1) embedding update dynamics, where rare tokens stagnate because of sparse gradient updates combined with weight decay, and (2) bilinear coupling, where the interaction between embeddings and downstream weights introduces saddle points and increases sensitivity to initialization. To confirm these mechanisms, we investigate frequency-aware sampling, which balances token updates by minimizing gradient variance, and embedding-specific learning rates, derived from the asymmetric curvature of the bilinear loss landscape. We prove that an adaptive learning-rate ratio, \(\frac{\eta_E}{\eta_W} \propto \frac{\sigma_{\max}(\mathbf{E})}{\sigma_{\max}(\mathbf{W})} \cdot \frac{f_W}{f_E}\), where \(\sigma_{\max}\) denotes the largest singular value, mitigates bilinear coupling effects and accelerates convergence. Our methods not only improve grokking dynamics but also extend to broader challenges in Transformer optimization, where bilinear interactions hinder efficient training.
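Mechanism (1) can be seen in a deliberately simplified toy model, which is our illustration rather than the paper's analysis. Assume each token owns a single scalar embedding, pulled toward a target of 1 by a squared-error gradient that arrives only on steps where the token is sampled (probability \(p\)), while decoupled weight decay \(\lambda\) shrinks every embedding on every step. Taking expectations gives the linear recursion \(e \leftarrow e - \eta\,(p\,(e - 1) + \lambda e)\), whose fixed point is \(p/(p + \lambda)\) and whose per-step contraction rate is \(\eta\,(p + \lambda)\), so rare tokens settle both more slowly and further from the target. Uniform token sampling is used here as a simple stand-in for frequency-aware sampling; the paper's variance-minimizing scheme may differ. The function names and hyperparameters are hypothetical.

```python
def expected_embedding_trajectory(probs, lr=0.1, weight_decay=0.01, target=1.0, steps=5000):
    """Iterate the expected (mean-field) SGD update of one scalar embedding per token.

    probs[i] is the ASSUMED probability that token i appears in a training step.
    The gradient term fires only when the token is sampled, but decoupled
    weight decay shrinks every embedding on every step.
    """
    emb = [0.0] * len(probs)
    for _ in range(steps):
        emb = [
            e - lr * (p * (e - target) + weight_decay * e)
            for e, p in zip(emb, probs)
        ]
    return emb


def fixed_point(p, weight_decay=0.01, target=1.0):
    """Closed-form equilibrium of the expected update: p * target / (p + weight_decay)."""
    return p * target / (p + weight_decay)


skewed = [0.5, 0.3, 0.15, 0.04, 0.01]  # natural, frequency-imbalanced sampling
uniform = [0.2] * 5                    # balanced sampling over the same five tokens

for name, probs in [("skewed", skewed), ("uniform", uniform)]:
    emb = expected_embedding_trajectory(probs)
    print(name, [round(e, 3) for e in emb])
```

Under skewed sampling the rarest token (\(p = 0.01\)) settles near 0.5 while the most frequent one approaches about 0.98; under uniform sampling all five embeddings settle at the same value, \(0.2/0.21 \approx 0.952\).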
