Deep Networks Always Grok and Here is Why

Ahmed Imtiaz Humayun, Randall Balestriero, Richard Baraniuk

TL;DR

This work shows that grokking, or delayed generalization, occurs broadly across practical DNNs, not just in contrived setups. It introduces Local Complexity (LC), a spline-based progress measure that quantifies how densely linear regions partition the input space around a point, independent of labels or loss. Training dynamics reveal three phases (descent, ascent, region migration), in which the non-linearities that delimit the linear regions migrate toward the decision boundary, forming a robust partition that enables generalization and robustness long after interpolation. The findings connect region migration to grokking, demonstrate delayed robustness to adversarial examples, and show that factors like architecture, activation, and Batch Normalization critically shape these dynamics. Together, the results provide a geometric, mechanistic lens on why and when networks learn to generalize and become robust, with implications for training regimes and model design.

Abstract

Grokking, or delayed generalization, is a phenomenon where generalization in a deep neural network (DNN) occurs long after achieving near zero training error. Previous studies have reported the occurrence of grokking in specific controlled settings, such as DNNs initialized with large-norm parameters or transformers trained on algorithmic datasets. We demonstrate that grokking is actually much more widespread and materializes in a wide range of practical settings, such as training of a convolutional neural network (CNN) on CIFAR10 or a ResNet on Imagenette. We introduce the new concept of delayed robustness, whereby a DNN groks adversarial examples and becomes robust, long after interpolation and/or generalization. We develop an analytical explanation for the emergence of both delayed generalization and delayed robustness based on the local complexity of a DNN's input-output mapping. Our local complexity measures the density of so-called linear regions (a.k.a. spline partition regions) that tile the DNN input space and serves as a useful progress measure for training. We provide the first evidence that, for classification problems, the linear regions undergo a phase transition during training whereafter they migrate away from the training samples (making the DNN mapping smoother there) and towards the decision boundary (making the DNN mapping less smooth there). Grokking occurs post phase transition as a robust partition of the input space thanks to the linearization of the DNN mapping around the training points. Website: https://bit.ly/grok-adversarial

Paper Structure (18 sections, 10 equations, 36 figures, 1 table)

Figures (36)

  • Figure 1: Deep Neural Networks grok robustness. When training a ResNet18 on CIFAR10, without any controlled initialization as in \cite{liu2022omnigrok}, the network starts grokking adversarial examples generated using Projected Gradient Descent \cite{madry2017towards} after $10^4$ optimization steps (top) and attains almost equal robustness and generalization performance after $2\times10^5$ steps. Prior to grokking, the local complexity, i.e., the local density of spline partition regions in the input space, undergoes a phase change during training (bottom). After test accuracy converges, the network starts migrating its non-linearities away from the data points and closer to the decision boundary (see \ref{fig:mnist-splinecam}), eventually reducing the complexity of the learned function around the data points. This increase and subsequent decrease in local non-linearity is visible for a wide variety of networks and training settings (see \ref{fig:resnet-cifar10-imagenette}). In this paper, we show that this particular training dynamic always results in delayed generalization or robustness.
  • Figure 2: Emergence of Robust Partition. We train a 4-layer ReLU Multi Layer Perceptron (MLP) of width $200$ on $1K$ samples from MNIST for $10^5$ optimization steps, with batch size $200$. The network starts grokking adversarial examples after approximately $10^4$ optimization steps (top-left). The local complexity around data points (bottom-left) follows a double descent curve, with the final descent also starting after approximately $10^4$ optimization steps. Where do the non-linearities migrate to? In the middle and right images we present analytically computed visualizations of the DNN input space partition \cite{Humayun_2023_CVPR}. The partition, i.e., the set of linear regions, is visualized across a 2D domain in the input space that intersects three training samples. During the final descent in local complexity, a unique structure emerges in the DNN partition geometry: a large number of non-linearities (black lines), and therefore linear regions, have concentrated around the decision boundary (red line). We dub this phenomenon Region Migration. An animation of an entire training run is available at https://bit.ly/grok-splinecam.
  • Figure 3: Curvature and complexity. Visual depiction of \ref{eq:CPA} with a toy affine spline $S : \mathbb{R}^2 \rightarrow \mathbb{R}$, obtained by training an MLP to regress the piecewise function $f(x_1,x_2) = \{\sin(x_1)+\cos(x_2)\}\mathbbm{1}_{x_1<0}$. Regions in the input space partition $\Omega$ (left) and the graph of the affine spline function (right) are randomly colored. The spline partition has a significantly higher density of non-linearities for $x_1<0$, i.e., the local complexity is higher where the learned function has more curvature.
  • Figure 4: Local Complexity Approximation. 1) Given a point in the input space $x\in \mathbb{R}^D$, we start by sampling $P$ orthonormal vectors $\{v_1,v_2,...,v_P\}$ to obtain a cross-polytopal frame ${\bm{V}}_x=\{x \pm r v_p : p=1,\dots,P\}$ centered on $x$, where $r$ is a radius parameter. We consider the convex hull $\mathrm{conv}({\bm{V}}_x)$ as the local neighborhood of $x$. 2) If any neuron hyperplane intersects the neighborhood $\mathrm{conv}({\bm{V}}_x)$, then the pre-activation sign differs across the vertices. We can therefore count, for a given layer, the number of neurons whose pre-activations change sign across the vertices of ${\bm{V}}_x$ to quantify the local complexity at $x$ for that layer. 3) By embedding ${\bm{V}}_x$ into the input of the next layer, we obtain a coarse approximation of the local neighborhood of $x$ and continue computing local complexity in a layerwise fashion.
  • Figure 5: Deformation with depth. Change of average eccentricity \cite{xu2021comparing} of the input space neighborhoods ${\bm{V}}_x$ by different layers of a CNN trained on the CIFAR10 dataset, for different radii $r$. For larger radii, the deformation increases with depth almost exponentially. For $r\leq 0.014$ deformation is low, indicating that smaller-radius neighborhoods are reliable for LC computation on deeper networks. Values are averaged over neighborhoods sampled for $1000$ training points from CIFAR10. For ResNet18, see \ref{fig:appendix_deformation_resnet}.
  • ...and 31 more figures
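
The three steps in the Figure 4 caption can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the network is assumed to be a plain ReLU MLP given as lists of weight matrices and bias vectors, the orthonormal directions are drawn via a QR decomposition of a Gaussian matrix, and the names `orthonormal_frame` and `local_complexity` as well as the default `radius` and `num_dirs` values are hypothetical.

```python
import numpy as np


def orthonormal_frame(dim, num_dirs, rng):
    """Sample `num_dirs` orthonormal directions in R^dim (needs num_dirs <= dim).

    Assumption: QR of a Gaussian matrix is used here as one convenient way
    to obtain orthonormal vectors.
    """
    q, _ = np.linalg.qr(rng.standard_normal((dim, num_dirs)))
    return q.T  # shape (num_dirs, dim), rows are orthonormal


def local_complexity(x, weights, biases, radius=0.05, num_dirs=4, rng=None):
    """Approximate the per-layer local complexity of a ReLU MLP around `x`.

    Returns one count per layer: the number of neurons whose pre-activation
    sign differs across the vertices of the neighborhood frame.
    """
    rng = np.random.default_rng(0) if rng is None else rng
    dirs = orthonormal_frame(x.shape[0], num_dirs, rng)
    # Step 1: cross-polytope vertices x + r*v_p and x - r*v_p, shape (2P, D).
    verts = np.concatenate([x + radius * dirs, x - radius * dirs], axis=0)
    counts = []
    for W, b in zip(weights, biases):
        pre = verts @ W.T + b  # pre-activations, shape (2P, layer width)
        # Step 2: a neuron counts if its sign is not the same at all vertices.
        positive = pre > 0
        crossed = positive.any(axis=0) & ~positive.all(axis=0)
        counts.append(int(crossed.sum()))
        # Step 3: embed the vertices into the next layer's input space.
        verts = np.maximum(pre, 0.0)
    return counts


# Usage on a randomly initialized toy MLP (sizes are arbitrary).
rng = np.random.default_rng(0)
widths = [8, 16, 16]
weights = [rng.standard_normal((widths[i + 1], widths[i])) for i in range(2)]
biases = [rng.standard_normal(widths[i + 1]) for i in range(2)]
x = rng.standard_normal(8)
print(local_complexity(x, weights, biases, radius=0.1, rng=rng))
```

For the first layer the sign test is exact, because the pre-activation is affine in the input and the neighborhood is the convex hull of the vertices. For deeper layers it is only the coarse approximation the caption describes, since the embedded vertices no longer span the true image of the neighborhood.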