Egalitarian Gradient Descent: A Simple Approach to Accelerated Grokking
Ali Saheb Pasand, Elvis Dohmatob
TL;DR
This work addresses the delayed generalization known as grokking by examining gradient spectra and introducing Egalitarian Gradient Descent (EGD), a simple gradient-whitening technique that makes training progress at the same speed along every principal direction of the gradient. By transforming the gradient as $\tilde{G}=(GG^T)^{-1/2}G$, EGD yields isotropic updates and provably accelerates grokking, with connections to natural gradient descent. The authors provide theoretical insights and demonstrate empirical gains on sparse parity and modular arithmetic tasks, showing substantially faster or even immediate generalization without harming final performance. The method is hyperparameter-free, memory-efficient, and easy to integrate, making it a practical way to shorten training plateaus across a range of settings. Limitations include the computational overhead of spectral decompositions, which motivates future work on scalable spectral approximations and broader benchmarks.
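The whitening formula can be illustrated with a minimal NumPy sketch, assuming $G$ is a single 2-D gradient matrix (for example, one layer's weight gradient). The function name `egalitarian_whiten` and the damping term `eps` are hypothetical additions for numerical safety; the source describes EGD as hyperparameter-free. For full-row-rank $G = USV^T$, the result equals $UV^T$, so every singular value of the update becomes 1.

```python
import numpy as np


def egalitarian_whiten(G: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Return (G G^T + eps*I)^{-1/2} G for a 2-D gradient matrix G.

    eps is an assumed damping term (not from the source) that keeps the
    inverse square root finite when G G^T is singular or ill-conditioned.
    """
    # G G^T is symmetric positive semi-definite, so eigh gives real eigenpairs.
    gram = G @ G.T
    eigvals, eigvecs = np.linalg.eigh(gram)
    # Clip tiny negative round-off before taking the inverse square root.
    inv_sqrt = 1.0 / np.sqrt(np.clip(eigvals, 0.0, None) + eps)
    # Rebuild (G G^T)^{-1/2} = V diag(inv_sqrt) V^T from the eigenpairs.
    gram_inv_sqrt = (eigvecs * inv_sqrt) @ eigvecs.T
    return gram_inv_sqrt @ G


# Usage: a gradient whose rows live on very different scales.
rng = np.random.default_rng(0)
G = rng.standard_normal((4, 6)) * np.array([[10.0], [1.0], [0.1], [0.01]])
G_tilde = egalitarian_whiten(G)
print(np.linalg.svd(G, compute_uv=False))        # widely spread singular values
print(np.linalg.svd(G_tilde, compute_uv=False))  # all approximately 1
```

The eigendecomposition of $GG^T$ follows the formula literally; an equivalent route for full-rank $G$ is to take the SVD and return `U @ Vt` directly.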
Abstract
Grokking is the phenomenon whereby, unlike the training performance, which peaks early in the training process, the test (generalization) performance of a model stagnates for arbitrarily many epochs and then suddenly jumps, usually to near-perfect levels. In practice, it is desirable to shorten such plateaus, that is, to make the learning process "grok" faster. In this work, we provide new insights into grokking. First, we show both empirically and theoretically that grokking can be induced by asymmetric speeds of (stochastic) gradient descent along different principal directions (i.e., singular directions) of the gradients. We then propose a simple modification that normalizes the gradients so that the dynamics along all principal directions evolve at exactly the same speed. Next, we establish that this modified method, which we call egalitarian gradient descent (EGD) and which can be seen as a carefully modified form of natural gradient descent, groks much faster. In fact, in some cases the stagnation is removed completely. Finally, we show empirically that on classical arithmetic problems such as modular addition and sparse parity, where this stagnation has been widely observed and intensively studied, our proposed method eliminates the plateaus.
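To make the "asymmetric speeds" argument concrete, here is a toy sketch (a hypothetical setup, not one of the paper's experiments) that runs plain gradient descent and an EGD-style update on a 2×2 least-squares problem whose input covariance has eigenvalues 1 and 0.01. Under GD, the error along eigendirection $i$ shrinks by a factor $(1-\eta\lambda_i)$ per step, so the low-curvature direction lags far behind. The EGD-style update replaces the gradient $G = USV^T$ by $UV^T$, which equals $(GG^T)^{-1/2}G$ when $G$ has full row rank, so both directions move by the same amount. The helper names `egd_direction` and `run` and the step sizes are illustrative assumptions, not tuned values.

```python
import numpy as np


def egd_direction(G: np.ndarray) -> np.ndarray:
    # For full-row-rank G = U S V^T, (G G^T)^{-1/2} G equals U V^T:
    # every singular value of the update is set to 1.
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt


def run(method: str, steps: int, lr: float) -> np.ndarray:
    # Toy loss L(W) = 0.5 * tr((W - W*) Sigma (W - W*)^T) with an
    # ill-conditioned input covariance Sigma (hypothetical setup).
    sigma = np.diag([1.0, 0.01])
    w_star = np.eye(2)
    W = np.zeros((2, 2))
    for _ in range(steps):
        G = (W - w_star) @ sigma  # gradient of L with respect to W
        update = G if method == "gd" else egd_direction(G)
        W = W - lr * update
    # Remaining error along each principal direction (diagonal of W* - W).
    return np.diag(w_star - W)


gd_err = run("gd", steps=10, lr=0.5)
egd_err = run("egd", steps=10, lr=0.05)
print("GD :", gd_err)   # fast direction nearly solved, slow direction barely moved
print("EGD:", egd_err)  # both directions have moved by the same amount
```

With a constant step size, the EGD iterate in this diagonal example behaves like a sign update and would oscillate within roughly `lr` of the optimum once the remaining error drops below the step; the demo stops before that point.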
