NeuralGrok: Accelerate Grokking by Neural Gradient Transformation

Xinyu Zhou, Simin Fan, Martin Jaggi, Jie Fu

TL;DR

This paper tackles the grokking phenomenon in Transformer models by introducing NeuralGrok, a bilevel gradient-transformation framework in which a neural amplifier learns to reweight gradient components to promote generalization. Through inner-outer loop optimization, the amplifier guides the base model toward faster and more stable generalization on arithmetic tasks, outperforming standard training and the GrokFast baseline. The authors further show that NeuralGrok reduces model and gradient complexity, as measured by the Absolute Gradient Entropy (AGE) metric, and that gradient rescaling alone can stabilize training even without learned transformations. While the approach yields clear benefits on synthetic arithmetic tasks, it also has limitations: the learned gradient transformations transfer poorly across tasks, and the method still needs validation on broader domains.
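The inner-outer loop described above can be illustrated with a deliberately small sketch. Everything in it is an assumption made for illustration: the toy linear-regression task, the hypothetical names `amplify` and `train`, the per-coordinate softplus weights standing in for the neural amplifier (the paper uses an MLP block), and the one-step lookahead hypergradient used for the outer update. It shows the structure (an inner step on the training loss with a transformed gradient, and an outer step on a held-out loss differentiated through that inner step), not the authors' implementation.

```python
import numpy as np

# Hypothetical toy problem: linear regression with a small training split and
# a larger held-out split that plays the role of the "generalization" signal.
rng = np.random.default_rng(0)
d = 8
w_true = np.zeros(d)
w_true[:2] = [1.0, -2.0]  # only two features are informative
X_tr = rng.normal(size=(16, d))
y_tr = X_tr @ w_true + 0.1 * rng.normal(size=16)
X_va = rng.normal(size=(64, d))
y_va = X_va @ w_true


def mse(w, X, y):
    """0.5 * mean squared error of the linear model w."""
    return 0.5 * np.mean((X @ w - y) ** 2)


def grad_mse(w, X, y):
    """Gradient of mse(w, X, y) with respect to w."""
    return X.T @ (X @ w - y) / len(y)


def softplus(x):
    return np.log1p(np.exp(x))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def amplify(theta, g):
    """Hypothetical amplifier: positive per-coordinate weights softplus(theta).
    Stands in for the paper's MLP block; it is NOT the paper's design."""
    return softplus(theta) * g


def train(steps=200, lr=0.1, meta_lr=0.5):
    w = np.zeros(d)      # base-model parameters
    theta = np.zeros(d)  # amplifier parameters
    for _ in range(steps):
        g_tr = grad_mse(w, X_tr, y_tr)
        # Outer step: differentiate the held-out loss through ONE lookahead
        # inner step, w' = w - lr * softplus(theta) * g_tr.
        w_look = w - lr * amplify(theta, g_tr)
        g_va = grad_mse(w_look, X_va, y_va)
        hypergrad = g_va * (-lr * g_tr) * sigmoid(theta)  # chain rule
        # Clipping theta bounds the amplifier weights and keeps the toy stable.
        theta = np.clip(theta - meta_lr * hypergrad, -4.0, 4.0)
        # Inner step: update the base model with the transformed gradient.
        w = w - lr * amplify(theta, g_tr)
    return w, theta


w, theta = train()
print(mse(np.zeros(d), X_va, y_va), mse(w, X_va, y_va))
```

The one-step lookahead keeps the hypergradient exact and cheap in this toy; the paper's actual outer objective, amplifier architecture, and update schedule may differ.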

Abstract

Grokking is a widely studied phenomenon in which generalization emerges only after a long period of overfitting. In this work, we propose NeuralGrok, a novel gradient-based approach that learns an optimal gradient transformation to accelerate the generalization of Transformers on arithmetic tasks. Specifically, NeuralGrok trains an auxiliary module (e.g., an MLP block) jointly with the base model. Guided by a bilevel optimization algorithm, this module dynamically modulates the influence of individual gradient components according to their contribution to generalization. Our extensive experiments demonstrate that NeuralGrok significantly accelerates generalization, particularly on challenging arithmetic tasks. We also show that NeuralGrok promotes a more stable training paradigm that consistently reduces model complexity, whereas traditional regularization methods such as weight decay can introduce substantial instability and impede generalization. We further investigate intrinsic model complexity using a novel Absolute Gradient Entropy (AGE) metric, which indicates that NeuralGrok facilitates generalization by reducing model complexity. These results offer insights into the grokking phenomenon in Transformer models and encourage a deeper understanding of the fundamental principles governing generalization.
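The AGE metric is named above but not defined in this excerpt. The sketch below assumes, purely for illustration, that it is the Shannon entropy of the absolute gradient entries after normalizing them into a probability distribution; both this definition and the function name `absolute_gradient_entropy` are hypothetical. Under that assumption, lower values mean the gradient is concentrated on fewer components.

```python
import numpy as np


def absolute_gradient_entropy(grads, eps=1e-12):
    """Hypothetical AGE-style score (assumed definition, not the paper's):
    Shannon entropy (in nats) of |g| normalized to sum to one, pooled over
    all parameter tensors in `grads`."""
    flat = np.concatenate([np.abs(np.ravel(g)) for g in grads])
    total = flat.sum()
    if total <= eps:
        return 0.0  # an all-zero gradient carries no distribution
    p = flat / total
    p = p[p > 0]  # treat 0 * log(1/0) as 0
    return float(np.sum(p * np.log(1.0 / p)))


# A gradient spread evenly over 4 entries has maximal entropy ln(4);
# a gradient concentrated on a single entry has entropy 0.
uniform = [np.ones(4)]
peaked = [np.array([3.0, 0.0, 0.0, 0.0])]
print(absolute_gradient_entropy(uniform))  # ≈ 1.386
print(absolute_gradient_entropy(peaked))   # 0.0
```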

Paper Structure

This paper contains 30 sections, 6 equations, 12 figures, 1 table, 2 algorithms.

Figures (12)

  • Figure 1: Train and test accuracies on arithmetic tasks. NeuralGrok consistently accelerates generalization under the grokking phenomenon, especially on the challenging task.
  • Figure 2: Standard training with various gradient rescaling coefficients $c$ on the (a+b) mod 97 task. With $c \in \{0.5, 1.0, 2.0\}$, training is effectively stabilized compared with standard training that leaves the gradient magnitude unchanged.
  • Figure 3: Standard training with standard gradient normalization ($c = 1.0$). Gradient normalization enables generalization on the challenging task (a$\times$c+b$\times$d-e) mod 7, on which standard training without gradient normalization fails.
  • Figure 4: NeuralGrok with various gradient rescaling coefficients $c$ on the (a+b) mod 97 task. The model reaches perfect test accuracy at a similar speed (about 1.3k steps) for $c$ ranging from $0.2$ to $1.0$, whereas a larger gradient magnitude ($c = 2.0$) can delay generalization.
  • Figure 5: Model complexity, measured by entropy, on the (a+b) mod 97 task. The transition windows for the Memorization and Generalization phases are marked in red and green, respectively.
  • ...and 7 more figures
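Figures 2 to 4 vary a gradient rescaling coefficient $c$, and Figure 3 calls $c = 1.0$ standard gradient normalization. A minimal sketch, assuming that rescaling means setting the global L2 norm of the gradient to $c$ while keeping its direction (the helper `rescale_gradient` is hypothetical):

```python
import numpy as np


def rescale_gradient(grads, c=1.0, eps=1e-12):
    """Rescale a list of gradient arrays so their global L2 norm equals c.
    Assumed reading of "gradient rescaling coefficient c"; c = 1.0 then
    corresponds to plain gradient normalization."""
    norm = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    if norm < eps:  # leave (near-)zero gradients untouched
        return [g.copy() for g in grads]
    scale = c / norm
    return [g * scale for g in grads]


grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]  # global norm 5
out = rescale_gradient(grads, c=2.0)
print(np.sqrt(sum(np.sum(g * g) for g in out)))  # ≈ 2.0
```

In a training loop, this would be applied to the (possibly amplifier-transformed) gradient just before the optimizer step. Because only the magnitude changes, $c$ acts like an adaptive step-size cap, which is consistent with Figure 4 showing similar speeds for $c$ between 0.2 and 1.0 but slower generalization at $c = 2.0$.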