Grokking at the Edge of Numerical Stability

Lucas Prieto, Melih Barsbey, Pedro A. M. Mediano, Tolga Birdal

TL;DR

The paper investigates why grokking—a sudden generalization after long overfitting—often requires regularization. It identifies Softmax Collapse (SC) as the key bottleneck in non-regularized setups, caused by floating-point absorption errors in the Softmax computation, and links SC to Naïve Loss Minimization (NLM), which scales logits without changing predictions and eventually triggers SC. To address this, the authors propose StableMax, a numerically stable Softmax variant, and $\perp\!\mathrm{Grad}$, an optimizer that eliminates the NLM direction, enabling grokking and faster generalization without weight decay. The results across modular arithmetic, sparse parity, and MNIST demonstrate that mitigating numerical instability and constraining gradient directions can explain and reproduce grokking-like generalization, offering a unified view of delayed generalization and the role of regularization in grokking.
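To make the absorption mechanism concrete, here is a minimal NumPy sketch of how Softmax cross-entropy can reach an exactly zero loss once the logits are large enough. The logit values, the scales, and the helper name `sce_loss_and_grad` are illustrative assumptions, not taken from the paper's code.

```python
import numpy as np

def sce_loss_and_grad(logits, label):
    """Softmax cross-entropy loss and its gradient w.r.t. the logits."""
    z = logits - logits.max()        # usual max-shift: avoids overflow, not absorption
    exp_z = np.exp(z)
    probs = exp_z / exp_z.sum()      # tiny terms can be absorbed in this sum
    loss = -np.log(probs[label])
    grad = probs.copy()
    grad[label] -= 1                 # d(loss)/d(logits) = softmax(logits) - onehot(label)
    return loss, grad

base = np.array([4.0, -1.0, -2.0])   # illustrative logits; class 0 is the correct class
results = {}
for dtype in (np.float32, np.float64):
    for scale in (1, 4, 8):
        logits = (base * scale).astype(dtype)
        loss, grad = sce_loss_and_grad(logits, label=0)
        # Store the loss and the correct-class gradient entry for inspection.
        results[(dtype.__name__, scale)] = (float(loss), float(grad[0]))
        print(dtype.__name__, scale, float(loss), float(grad[0]))
```

In this toy setting the float32 run loses the correct-class gradient at a smaller logit scale than the float64 run, which mirrors the float32 versus float64 comparison described in the Figure 2 caption.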

Abstract

Grokking, the sudden generalization that occurs after prolonged overfitting, is a surprising phenomenon challenging our understanding of deep learning. Although significant progress has been made in understanding grokking, the reasons behind the delayed generalization and its dependence on regularization remain unclear. In this work, we argue that without regularization, grokking tasks push models to the edge of numerical stability, introducing floating point errors in the Softmax function, which we refer to as Softmax Collapse (SC). We demonstrate that SC prevents grokking and that mitigating SC enables grokking without regularization. Investigating the root cause of SC, we find that beyond the point of overfitting, the gradients strongly align with what we call the naïve loss minimization (NLM) direction. This component of the gradient does not alter the model's predictions but decreases the loss by scaling the logits, typically by scaling the weights along their current direction. We show that this scaling of the logits explains the delay in generalization characteristic of grokking and eventually leads to SC, halting further learning. To validate our hypotheses, we introduce two key contributions that address the challenges in grokking tasks: StableMax, a new activation function that prevents SC and enables grokking without regularization, and $\perp$Grad, a training algorithm that promotes quick generalization in grokking tasks by preventing NLM altogether. These contributions provide new insights into grokking, elucidating its delayed generalization, reliance on regularization, and the effectiveness of existing grokking-inducing methods. Code for this paper is available at https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability.
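The NLM idea in the abstract can be illustrated with a small sketch: for a model that already classifies its training set correctly, scaling the weights along their current direction lowers the loss without changing any prediction. The bias-free linear model, the random data, and the scale factors below are assumptions chosen only for illustration.

```python
import numpy as np

def sce_loss(logits, labels):
    """Mean Softmax cross-entropy, computed via a log-sum-exp for stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))
W = rng.normal(size=(5, 3))          # bias-free linear model: logits scale linearly with W
labels = (X @ W).argmax(axis=1)      # labels chosen so the model already fits the data

losses, predictions = [], []
for alpha in (1.0, 2.0, 4.0):
    logits = X @ (alpha * W)         # move along the current weight direction
    losses.append(sce_loss(logits, labels))
    predictions.append(logits.argmax(axis=1))
    print(alpha, losses[-1])
```

The loss keeps shrinking as `alpha` grows while the predicted classes stay the same, which is the sense in which this direction minimizes the loss "naïvely".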

Paper Structure (48 sections, 2 theorems, 18 equations, 17 figures, 1 table)

Key Result

Proposition 1

$\mathrm{StableMax}$ is a modified $\mathrm{Softmax}$, i.e. $\mathrm{StableMax}\left(x_i\right) = \mathrm{Softmax}\left(g\left(x_i\right)\right)$ for an elementwise function $g$.
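The form of $g$ is not given here, so the following NumPy sketch rests on an assumption: that StableMax normalizes $s(x) = x + 1$ for $x \ge 0$ and $s(x) = 1/(1 - x)$ for $x < 0$, which would make $g = \log s$. Under that assumption the identity in Proposition 1 can be checked numerically; all function names are hypothetical.

```python
import numpy as np

def s(x):
    """Assumed StableMax base function: x + 1 for x >= 0, 1 / (1 - x) for x < 0."""
    x = np.asarray(x, dtype=np.float64)
    # 1 + |x| equals 1 - x on the negative branch and is never zero, so
    # np.where can evaluate both branches without a division warning.
    return np.where(x >= 0, x + 1.0, 1.0 / (1.0 + np.abs(x)))

def g(x):
    """Assumed g = log(s): log(1 + x) for x >= 0, -log(1 - x) for x < 0."""
    x = np.asarray(x, dtype=np.float64)
    return np.sign(x) * np.log1p(np.abs(x))

def softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def stablemax(x):
    """Normalize s(x) instead of exp(x)."""
    sx = s(x)
    return sx / sx.sum(axis=-1, keepdims=True)

logits = np.array([3.0, -2.0, 0.0, 0.5])
print(stablemax(logits))
print(softmax(g(logits)))   # should agree with the line above
```

Because $s$ grows linearly rather than exponentially, the terms in the normalizing sum stay much closer in magnitude, which is one plausible reading of why this variant is less exposed to absorption errors.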

Figures (17)

  • Figure 1: Our contributions demonstrated through results obtained on the addition modulo 113 task. We show that the delay in generalization induced by NLM can be reversed using the proposed $\perp$AdamW ((a) and (b)) and that the numerical errors that lead to overfitting instead of grokking can be avoided by using the proposed $\mathrm{StableMax}$ ((b) and (c)).
  • Figure 2: As dataset size increases (subplots a to c), MLPs trained on modular addition begin to generalize without regularization until SC halts learning by making the gradient from a large fraction of the samples equal to zero. This stopping point comes earlier for $\mathrm{float32}$ than for $\mathrm{float64}$, and with small enough datasets it comes before the model makes any progress on test accuracy.
  • Figure 3: $s(x)~\mathrm{vs.}~e^x$.
  • Figure 4: (left) Grokking with StableMax cross-entropy (StCE) loss and no regularization on three common grokking datasets, using an MLP with 2 hidden layers of width 200. We use 40% of all pairs modulo 113, the same setting as in Figure 2, where regular SCE gets stuck at chance-level performance (chance level is 50% for sparse parity). (middle) Evolution of model weight norms during training for the same models and tasks. This shows that grokking induced without weight decay does not follow the commonly observed trend of rapidly decreasing weight norm during generalization. (right) Changing input representations turns modular addition into a regular machine learning task, with train and test accuracy increasing in tandem; see the section on NLM.
  • Figure 5: MLPs with (a) and without (b) bias terms trained on modular addition receive updates that are significantly aligned with the direction of NLM beyond the point of overfitting. In (c) we show these results for a selection of parameters of our one-layer transformer. We highlight the embed and unembed matrices as well as the weights of the MLP, using the notation from Elhage et al. (2021).
  • ...and 12 more figures

Theorems & Definitions (12)

  • Definition 1: Absorption Errors
  • Definition 2: Softmax Cross-Entropy (SCE) loss
  • Definition 3: Softmax Collapse (SC)
  • Definition 4: $\mathrm{StableMax}$
  • Proposition 1
  • Definition 5: Naïve Loss Minimization (NLM)
  • Definition 6: Positive Homogeneity (Lyu and Li, 2019)
  • Definition 7: $\perp$Grad
  • Proposition 2
  • Proof: Sketch of the proof.
  • ...and 2 more
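The definition of $\perp$Grad is only listed by name above, so the sketch below assumes the natural reading of "preventing NLM": subtract from the gradient its projection onto the current weights, leaving an update orthogonal to the weight direction. The function names, the learning rate, and the plain gradient-descent step are hypothetical choices for illustration.

```python
import numpy as np

def perp_grad(theta, grad, eps=1e-30):
    """Remove from `grad` its component parallel to `theta` (assumed form of the projection)."""
    coeff = np.sum(grad * theta) / (np.sum(theta * theta) + eps)
    return grad - coeff * theta

def perp_sgd_step(theta, grad, lr=0.1):
    """One plain gradient-descent step using the projected gradient."""
    return theta - lr * perp_grad(theta, grad)

rng = np.random.default_rng(0)
theta = rng.normal(size=(4, 3))      # stand-in for one weight matrix
grad = rng.normal(size=(4, 3))       # stand-in for its loss gradient

g_perp = perp_grad(theta, grad)
new_theta = perp_sgd_step(theta, grad)

# The projected gradient has (numerically) no component along the weights.
print(np.sum(g_perp * theta))
# A gradient that only rescales the weights is removed entirely.
print(np.abs(perp_grad(theta, 2.5 * theta)).max())
```

Whether the projection is applied per weight matrix or to the full parameter vector is not stated in this excerpt; the sketch treats `theta` as a single block.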