Predicting Grokking Long Before it Happens: A look into the loss landscape of models which grok

Pascal Jr. Tikeng Notsawo, Hattie Zhou, Mohammad Pezeshki, Irina Rish, Guillaume Dumas

TL;DR

The paper introduces a low-cost predictor for grokking based on early training dynamics. The authors propose a spectral signature derived from the Fourier transform of the initial loss curve, quantified via Hjorth activity, to forecast whether grokking will occur later, and they show a power-law relation $t_4(r)=a r^{-\gamma}+b$ between the grokking time $t_4$ and the training-data fraction $r$. They analyze the grokking loss landscape through 1D projections and Hessian eigenvalues, revealing a perturbed, ill-conditioned terrain and a slingshot mechanism that accompanies delayed generalization. The study connects learning dynamics, landscape geometry, and generalization, offering practical guidance for hyperparameter selection and insights that may generalize to larger models and other domains.
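The power-law relation $t_4(r)=a r^{-\gamma}+b$ can be fitted to measured grokking times with a small amount of code. The sketch below is one possible approach, not the authors' procedure: it scans a grid of $\gamma$ values and, for each, solves a linear least-squares problem for $a$ and $b$. The function name `fit_power_law`, the grid bounds, and the data points are all assumptions made for illustration; the data is synthetic and is not taken from the paper.

```python
import numpy as np

def fit_power_law(r, t4, gammas=np.linspace(0.1, 5.0, 491)):
    """Fit t4 ~ a * r**(-gamma) + b by grid search over gamma.

    For a fixed gamma the model is linear in (a, b), so each grid point
    only needs an ordinary least-squares solve.
    """
    r = np.asarray(r, dtype=float)
    t4 = np.asarray(t4, dtype=float)
    best = None
    for gamma in gammas:
        design = np.column_stack([r ** -gamma, np.ones_like(r)])
        coef, *_ = np.linalg.lstsq(design, t4, rcond=None)
        sse = float(np.sum((design @ coef - t4) ** 2))
        if best is None or sse < best[0]:
            best = (sse, float(coef[0]), float(gamma), float(coef[1]))
    _, a, gamma, b = best
    return a, gamma, b

# Synthetic example (made-up numbers, NOT results from the paper).
r_values = np.array([0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
t4_values = 500.0 * r_values ** -2.0 + 100.0
a_hat, gamma_hat, b_hat = fit_power_law(r_values, t4_values)
```

On real measurements the fit is noisy, so the grid resolution and the range of $\gamma$ would need to be chosen with the data in mind.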

Abstract

This paper focuses on predicting the occurrence of grokking in neural networks, a phenomenon in which perfect generalization emerges long after signs of overfitting or memorization are observed. It has been reported that grokking can only be observed with certain hyper-parameters. This makes it critical to identify the parameters that lead to grokking. However, since grokking occurs after a large number of epochs, searching for the hyper-parameters that lead to it is time-consuming. In this paper, we propose a low-cost method to predict grokking without training for a large number of epochs. In essence, by studying the learning curve of the first few epochs, we show that one can predict whether grokking will occur later on. Specifically, if certain oscillations occur in the early epochs, one can expect grokking to occur if the model is trained for a much longer period of time. We propose using the spectral signature of a learning curve derived by applying the Fourier transform to quantify the amplitude of low-frequency components to detect the presence of such oscillations. We also present additional experiments aimed at explaining the cause of these oscillations and characterizing the loss landscape.
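As a rough illustration of the idea in the abstract, the following sketch computes two quantities from the first few hundred steps of a training-loss curve: the Hjorth activity (the variance of the signal, which by Parseval's theorem equals its total spectral power once the mean is removed) and the share of that power carried by the lowest frequencies. The function name `spectral_signature`, the `low_freq_fraction` cutoff, and the synthetic curves are assumptions for illustration; the paper's exact preprocessing and thresholds are not reproduced here.

```python
import numpy as np

def spectral_signature(loss, low_freq_fraction=0.1):
    """Return (activity, low_freq_share) for an early-training loss curve.

    activity       : Hjorth activity, i.e. the variance of the signal.
    low_freq_share : fraction of non-DC spectral power that falls in the
                     lowest `low_freq_fraction` of frequency bins
                     (an illustrative choice, not the paper's definition).
    """
    x = np.asarray(loss, dtype=float)
    x = x - x.mean()                          # remove the DC component
    activity = float(np.mean(x ** 2))         # variance of the curve
    power = np.abs(np.fft.rfft(x)) ** 2       # one-sided power spectrum
    k = max(1, int(low_freq_fraction * len(power)))
    total = max(float(power[1:].sum()), 1e-12)
    low_freq_share = float(power[1:k + 1].sum()) / total
    return activity, low_freq_share

# Synthetic curves (made up for illustration, not data from the paper).
steps = np.arange(400)
smooth_loss = np.exp(-steps / 100.0)
oscillating_loss = smooth_loss + 0.2 * np.sin(2 * np.pi * steps / 50.0)

activity_smooth, _ = spectral_signature(smooth_loss)
activity_osc, share_osc = spectral_signature(oscillating_loss)
```

In a hyperparameter sweep, one would compute this signature per configuration after a short run and compare it against the final validation accuracy, as the authors do in Figure 2.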

Paper Structure (31 sections, 6 equations, 30 figures, 1 algorithm)


Figures (30)

  • Figure 1: Oscillation in training and validation accuracies. We train on modular addition with $r=0.5$. The left curve shows a case where the model did not grok after 10k steps of training. The right curve shows the generalization after overfitting. Training accuracy becomes close to perfect at $t_2 < 300$ optimization steps, but it takes close to $t_4 \approx 7k$ steps for validation accuracy to reach that level.
  • Figure 2: Top: validation accuracy (%) at the end of training ($10k$ steps). Bottom: spectral energy (activity) of the training loss over the first 400 training steps ($r = 0.5$). The x-axis is the weight decay strength and the y-axis is the learning rate. The oscillation patterns in the training loss during the initial stages of training resemble the pattern of final validation accuracy, which suggests that the spectral signature can be used as an indicator or proxy for the upcoming grokking phenomenon. The highest degree of generalization is typically observed with small learning rates and small weight decay. While large learning rates may increase oscillations, this does not directly lead to grokking and is not necessarily evident in the early stages of training. Instead, such effects become more noticeable near the basin of attraction of the minimum.
  • Figure 3: 1D projection of the grokking loss and accuracy surface ($r=0.3$). The x-axis is $\alpha \in [-3, 3]$, and the y-axes show the loss (left axis) and accuracy (right axis) at $\theta^* + \alpha \delta$ with $\delta \propto \theta_0 - \theta^*$ (with an adaptation of the filter-wise normalization of Li et al., 2018), where $\theta_0$ is the initial parameter and $\theta^*$ is the parameter just after the model has grokked (a) or before grokking (b). The 1D subspace from initialization to grokking contains many difficult and exotic structures. In (a), there are two minimizers of the training loss (solid line), but only one of them also minimizes the validation loss.
  • Figure 4: 1D projection of the grokking loss and accuracy surface. This corresponds to Figure 3(a), but for several training epochs ($r=0.3$). The direction $\vec{\delta}_t$ used for each training epoch $t$ is the unit vector of $\theta^* - \theta_t$ (filter-wise normalization, Li et al., 2018), that is, the direction from the parameter at epoch $t$ to the minimum. Here, the structure is more exotic. There are clearly two minimizers of the training loss, but only one minimizes the validation loss: during memorization, the model sits in this local minimum, and it achieves grokking when it successfully breaks free from this local solution.
  • Figure 5: Similar to Figure 4, but with $r=0.8$. What changes between the two figures is the time it takes for the model to reach the global minimum. This time decreases with $r$ as $\Theta \left( r^{-\gamma} \right)$, with $\gamma > 0$ (see the appendix section on grokking).
  • ...and 25 more figures
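
The 1D projections described in Figures 3 to 5 evaluate the loss along a line through parameter space, $\theta^* + \alpha \delta$. A minimal sketch of that evaluation is given below, assuming flattened parameter vectors and a toy quadratic loss. The helper `loss_along_direction` and the toy loss are hypothetical, and the filter-wise normalization of Li et al. (2018) is deliberately omitted, so $\alpha = 1$ corresponds exactly to $\theta_0$.

```python
import numpy as np

def loss_along_direction(loss_fn, theta_star, theta_0, alphas):
    """Evaluate loss_fn at theta_star + alpha * delta for each alpha.

    delta = theta_0 - theta_star, without filter-wise normalization
    (a simplification relative to the captions above).
    """
    theta_star = np.asarray(theta_star, dtype=float)
    delta = np.asarray(theta_0, dtype=float) - theta_star
    return np.array([loss_fn(theta_star + a * delta) for a in alphas])

# Toy setup (an assumption for illustration, not the paper's model).
target = np.array([1.0, -2.0, 0.5])

def toy_loss(theta):
    # Quadratic bowl with its single minimum at `target`.
    return float(np.sum((theta - target) ** 2))

theta_init = np.zeros(3)                # stands in for theta_0
alphas = np.linspace(-3.0, 3.0, 61)     # same range as in Figure 3
profile = loss_along_direction(toy_loss, target, theta_init, alphas)
```

With a real network, `loss_fn` would load the flattened vector back into the model and return the training or validation loss. Plotting both profiles over the same `alphas` gives curves of the kind shown in the figures, where a training-loss minimizer that is not a validation-loss minimizer shows up as a second dip in only one of the two curves.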