Provable Scaling Laws of Feature Emergence from Learning Dynamics of Grokking
Yuandong Tian
TL;DR
The paper introduces Li$_2$, a principled gradient-dynamics framework that decomposes grokking into three stages: lazy learning, independent feature learning, and interactive feature learning. It shows how leaked gradient signals in Stage I trigger Stage II’s energy-driven, nodewise feature emergence, and how Stage III interactions promote diversity and refinement via repulsion and top-down modulation, with Muon accelerating exploration. The framework provides provable scaling laws for when the learned features generalize and when the model memorizes, characterizes local maxima of an energy landscape tied to nonlinear canonical correlation, and extends to deeper architectures. It also explains the role of hyperparameters (weight decay, learning rate, data size, Muon) in shaping grokking and provides a path toward a first-principles understanding of feature emergence in structured-input settings. The results unify group-theoretic structure with gradient dynamics to account for efficient feature representations and generalization under data constraints, with practical implications for optimizer design and data efficiency in structured tasks.
Abstract
While the phenomenon of grokking, i.e., delayed generalization, has been studied extensively, it remains an open problem whether there is a mathematical framework that characterizes what kind of features will emerge, how and under which conditions they emerge, and how this relates to the gradient dynamics of training, for complex structured inputs. We propose a novel framework, named $\mathbf{Li}_2$, that captures three key stages of the grokking behavior of 2-layer nonlinear networks: (I) lazy learning, (II) independent feature learning, and (III) interactive feature learning. In the lazy learning stage, the top layer overfits to random hidden representations and the model appears to memorize; at the same time, the backpropagated gradient $G_F$ from the top layer begins to carry information about the target label, with a specific structure that enables each hidden node to learn its representation independently. Interestingly, the independent dynamics follow exactly the gradient ascent of an energy function $E$, and its local maxima are precisely the emerging features. We study whether these features induced by local optima are generalizable, their representation power, and how they change with sample size, in group arithmetic tasks. When hidden nodes start to interact in the later stage of learning, we prove how $G_F$ changes to focus on the missing features that still need to be learned. Our study sheds light on the roles played by key hyperparameters such as weight decay, learning rate, and sample size in grokking, leads to provable scaling laws of feature emergence, memorization, and generalization, and reveals why recent optimizers such as Muon can be effective, from the first principles of gradient dynamics. Our analysis can be extended to multi-layer networks. The code is available at https://github.com/yuandong-tian/understanding/tree/main/ssl/real-dataset/cogo.
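To make the setting concrete, the following is a minimal sketch of one possible group arithmetic task, assuming modular addition $(a + b) \bmod M$ with concatenated one-hot inputs and a random train/test split. The function names, the encoding, and the values of `M` and `train_fraction` are illustrative assumptions and are not taken from the released code; the sketch only shows how the sample size enters as the fraction of all $M^2$ pairs used for training.

```python
import random


def one_hot(i, n):
    """Return a length-n list with 1.0 at index i and 0.0 elsewhere."""
    return [1.0 if j == i else 0.0 for j in range(n)]


def make_modular_addition(M, train_fraction, seed=0):
    """Build all M*M pairs (a, b) with label (a + b) % M and split them.

    Hypothetical helper for illustration: the input is the concatenation
    of one-hot(a) and one-hot(b), so each row has length 2 * M.
    """
    pairs = [(a, b) for a in range(M) for b in range(M)]
    rng = random.Random(seed)
    rng.shuffle(pairs)
    n_train = int(round(train_fraction * len(pairs)))

    def encode(subset):
        X = [one_hot(a, M) + one_hot(b, M) for a, b in subset]
        y = [(a + b) % M for a, b in subset]
        return X, y

    return encode(pairs[:n_train]), encode(pairs[n_train:])


# Assumed example values: M = 7 and 40% of all pairs used for training.
(X_train, y_train), (X_test, y_test) = make_modular_addition(M=7, train_fraction=0.4)
```

Varying `train_fraction` in such a setup is one way to probe how memorization and generalization depend on the number of training samples.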
