Egalitarian Gradient Descent: A Simple Approach to Accelerated Grokking

Ali Saheb Pasand, Elvis Dohmatob

TL;DR

This work addresses the delayed generalization known as grokking by examining gradient spectra and introducing Egalitarian Gradient Descent (EGD), a simple gradient-whitening technique that equalizes progression across principal directions. By transforming the gradient as $\tilde{G}=(GG^T)^{-1/2}G$, EGD yields isotropic updates and provably accelerates grokking, with connections to natural gradient descent. The authors provide theoretical insights and demonstrate empirical gains on sparse parity and modular arithmetic tasks, showing substantially faster or even immediate generalization without harming final performance. The method is hyperparameter-free, memory-efficient, and easy to integrate, offering practical impact for reducing training plateaus in a range of settings. Limitations include computational overhead from spectral decompositions, motivating future work on scalable spectral approximations and broader benchmarks.
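
A minimal NumPy sketch of the transformation $\tilde{G}=(GG^T)^{-1/2}G$ stated above. It assumes $G$ is a dense 2-D gradient matrix and relies on the identity that, for a thin SVD $G = U\,\mathrm{diag}(s)\,V^T$ with all $s_i > 0$, the whitened gradient equals $UV^T$. The function name `egalitarian_gradient` and the `rel_tol` cutoff for numerically zero singular values are illustrative assumptions, not taken from the paper.

```python
import numpy as np

def egalitarian_gradient(G, rel_tol=1e-12):
    """Whiten a gradient matrix: return (G G^T)^{-1/2} G.

    With the thin SVD G = U diag(s) V^T this equals U V^T on the
    directions whose singular value is nonzero. `rel_tol` is an
    assumption of this sketch: directions with s <= rel_tol * s_max
    are treated as numerically zero and dropped.
    """
    U, s, Vt = np.linalg.svd(G, full_matrices=False)
    if s.size == 0 or s[0] == 0.0:
        return np.zeros(G.shape)
    keep = s > rel_tol * s[0]
    return U[:, keep] @ Vt[keep, :]

# Synthetic gradient whose rows have very different scales.
rng = np.random.default_rng(0)
row_scales = np.array([[100.0], [10.0], [1.0], [0.1]])
G = rng.standard_normal((4, 6)) * row_scales

G_tilde = egalitarian_gradient(G)
singular_values = np.linalg.svd(G_tilde, compute_uv=False)
```

In this sketch every entry of `singular_values` is 1 up to floating-point error, which is the sense in which the update is isotropic: no principal direction moves faster than another.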

Abstract

Grokking is the phenomenon whereby, unlike the training performance, which peaks early in the training process, the test/generalization performance of a model stagnates over arbitrarily many epochs and then suddenly jumps, usually to near-perfect levels. In practice, it is desirable to reduce the length of such plateaus, that is, to make the learning process "grok" faster. In this work, we provide new insights into grokking. First, we show both empirically and theoretically that grokking can be induced by asymmetric speeds of (stochastic) gradient descent along different principal (i.e., singular) directions of the gradients. We then propose a simple modification that normalizes the gradients so that the dynamics along all principal directions evolve at exactly the same speed. We then establish that this modified method, which we call egalitarian gradient descent (EGD) and which can be seen as a carefully modified form of natural gradient descent, groks much faster. In fact, in some cases the stagnation is completely removed. Finally, we show empirically that on classical arithmetic problems such as modular addition and sparse parity, on which this stagnation has been widely observed and intensively studied, our proposed method eliminates the plateaus.
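
The effect of asymmetric speeds along different directions can be illustrated with a toy example. It is not the paper's setup: it assumes a two-dimensional diagonal quadratic loss with a hypothetical curvature gap `eps`, and only shows that plain gradient descent is bottlenecked by the slowest direction.

```python
def steps_to_converge(eps, lr=1.0, tol=1e-3, max_steps=10**7):
    """Run GD on f(w) = 0.5 * (w1**2 + eps * w2**2) from w = (1, 1).

    Returns the number of steps until both coordinates are below `tol`
    in absolute value (or `max_steps` if that never happens).
    """
    w1, w2 = 1.0, 1.0
    for k in range(max_steps):
        if abs(w1) < tol and abs(w2) < tol:
            return k
        w1 -= lr * w1          # fast direction, curvature 1
        w2 -= lr * eps * w2    # slow direction, curvature eps
    return max_steps

# Step counts for a shrinking curvature gap.
gaps = [1e-1, 1e-2, 1e-3]
steps = [steps_to_converge(eps) for eps in gaps]
```

The fast coordinate converges almost immediately, while the step count grows roughly in proportion to `1 / eps`: each tenfold decrease in `eps` costs about ten times as many steps. Equalizing the speed of the two directions is what removes this dependence.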

Paper Structure

This paper contains 29 sections, 4 theorems, 25 equations, 6 figures, 1 table.

Key Result

Theorem 1

For large $n$, it holds w.h.p that: for any iteration $k \ge 1$,

Figures (6)

  • Figure 1: Results on Modular Addition for different values of the modulus $p$. Solid lines correspond to test accuracy and broken lines correspond to train accuracy. In all cases, our proposed EGD (egalitarian gradient descent) method groks after only a few epochs, while vanilla (stochastic) gradient descent stagnates for a long period before eventually grokking. We also include "Column Normalization", a simplification of EGD which simply rescales the columns of gradient matrices by dividing by their $L_2$ norm. Even this simplification seems to grok much faster than the baseline, vanilla (S)GD. Refer to Section \ref{sec:exp} for details and to Appendix \ref{app:hyp} for the hyperparameters used.
  • Figure 2: Results on Modular Multiplication for different values of the modulus $p$. Solid lines correspond to test accuracy and broken lines correspond to train accuracy. In all cases, our proposed EGD method groks after only a few epochs, while all the other methods stagnate for a long period before eventually grokking. Refer to Section \ref{sec:exp} for details and to Appendix \ref{app:hyp} for the hyperparameters used.
  • Figure 3: Results on Sparse Parity Problem. Solid lines correspond to test accuracy and broken lines correspond to train accuracy. All three plots show that our method (EGD) groks significantly faster than other methods. Refer to Section \ref{sec:exp} for details on the experimental setup and to Appendix \ref{app:hyp} for the hyperparameters used.
  • Figure 4: Ill-conditioned Gradient Spectra cause delayed generalization. We consider the problem of learning addition modulo 97 from data, with a two-layer ReLU neural network. From the start of optimization through to the end, the gradient matrix $G$ for the hidden layer has a poor condition number. Here, we see that the largest singular value (corresponding to a fast direction) is much larger than the smallest (corresponding to slow directions). This causes the overall dynamics of vanilla (S)GD to stall for arbitrarily long times, leading to delayed generalization (see Figure \ref{fig:modadd}). Our proposed method, EGD (egalitarian gradient descent), forces all the singular values of $G$ to be equal.
  • Figure 5: A Toy Setup which Induces Stagnation in Gradient Descent (GD). Training data points correspond to circles and test data points correspond to stars (middle region). The broken lines correspond to the large margin of the training data (their separation is $2s$), while the solid line is the ground-truth decision-boundary $x^{(1)}=0$. The variance of the slow feature $x^{(2)}$ scales like $\varepsilon \ll 1$. GD would quickly find a linear model which perfectly separates the training data but will take a time of order $1/\varepsilon$ to find the ground-truth model which attains perfect test accuracy. See Figure \ref{fig:plateau-length}.
  • ...and 1 more figure
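
The "Column Normalization" variant from the Figure 1 caption and the singular-value equalization from the Figure 4 caption can be compared on a synthetic matrix. This is a sketch under stated assumptions: the random matrix `G` stands in for a real hidden-layer gradient, and the helper names `column_normalize`, `whiten` and `condition_number` are hypothetical. `whiten` additionally assumes `G` has full rank.

```python
import numpy as np

def column_normalize(G, eps=1e-12):
    """Divide each column of G by its L2 norm (`eps` guards against zero columns)."""
    norms = np.linalg.norm(G, axis=0, keepdims=True)
    return G / np.maximum(norms, eps)

def whiten(G):
    """Set all singular values of G to 1 via the thin SVD (assumes full rank)."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def condition_number(M):
    """Ratio of the largest to the smallest singular value."""
    s = np.linalg.svd(M, compute_uv=False)
    return s[0] / s[-1]

# Synthetic ill-conditioned "gradient": columns span three orders of magnitude.
rng = np.random.default_rng(0)
G = rng.standard_normal((8, 8)) * np.logspace(0, 3, 8)

G_col = column_normalize(G)
cond_raw = condition_number(G)
cond_col = condition_number(G_col)
cond_egd = condition_number(whiten(G))
```

Column normalization fixes every column norm to 1 but does not in general make the singular values equal, whereas whitening drives the condition number to 1. Column normalization avoids the SVD entirely, which is why it is a cheaper simplification.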

Theorems & Definitions (5)

  • Theorem 1
  • Corollary 1
  • Remark 1
  • Lemma 1
  • Lemma 2