A Tale of Two Circuits: Grokking as Competition of Sparse and Dense Subnetworks

William Merrill, Nikolaos Tsilivis, Aman Shukla

TL;DR

The paper investigates grokking on a sparse parity task and shows that a sparse subnetwork emerges and dominates predictions after the grokking transition, while a dense subnetwork governs behavior beforehand. This transition is associated with targeted norm growth in a subset of neurons, leading to effective sparsification and a competitive dynamic between subnetworks. The findings connect sparsity and norm dynamics to generalization, suggesting a mechanism that could underlie emergent behaviors in larger models. The work has implications for understanding how targeted weight growth and subnetwork specialization enable robust generalization in overparameterized networks.

Abstract

Grokking is a phenomenon where a model trained on an algorithmic task first overfits but, then, after a large amount of additional training, undergoes a phase transition to generalize perfectly. We empirically study the internal structure of networks undergoing grokking on the sparse parity task, and find that the grokking phase transition corresponds to the emergence of a sparse subnetwork that dominates model predictions. On an optimization level, we find that this subnetwork arises when a small subset of neurons undergoes rapid norm growth, whereas the other neurons in the network decay slowly in norm. Thus, we suggest that the grokking phase transition can be understood to emerge from competition of two largely distinct subnetworks: a dense one that dominates before the transition and generalizes poorly, and a sparse one that dominates afterwards.

Paper Structure

This paper contains 17 sections, 3 theorems, 4 equations, and 9 figures.

Key Result

Proposition 1

For any $n$, there exists a 1-layer ReLU net with $2^k$ neurons that computes $(n, k)$-parity.
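One way to see why $2^k$ neurons suffice is the construction sketched below. It is not necessarily the construction used in the paper's proof. It assumes inputs $x \in \{-1,+1\}^n$ and a label equal to the product of $k$ fixed coordinates $S$ (the usual sparse-parity setup). Each hidden neuron $h_s(x) = \mathrm{ReLU}(\langle s, x_S\rangle - (k-1))$ is an indicator for one sign pattern $s \in \{-1,+1\}^k$, and its readout weight is the parity $\prod_i s_i$ of that pattern. The function names `build_parity_net` and `forward` are hypothetical.

```python
import itertools

import numpy as np


def build_parity_net(n, support):
    """Hypothetical 1-layer ReLU net with 2^k hidden neurons for (n, k)-parity.

    Inputs are assumed to lie in {-1, +1}^n; `support` lists the k relevant coordinates.
    """
    k = len(support)
    # One hidden neuron per sign pattern s in {-1, +1}^k.
    patterns = np.array(list(itertools.product([-1, 1], repeat=k)))
    W = np.zeros((2 ** k, n))
    W[:, support] = patterns
    # <s, x_S> equals k on a match and is at most k - 2 otherwise,
    # so ReLU(<s, x_S> - (k - 1)) is exactly 1 on a match and 0 otherwise.
    b = -(k - 1) * np.ones(2 ** k)
    # Readout weight of each neuron is the parity (product) of its pattern.
    a = patterns.prod(axis=1)
    return W, b, a


def forward(X, W, b, a):
    """Evaluate the network on a batch X of shape (m, n); returns values in {-1, +1}."""
    h = np.maximum(X @ W.T + b, 0.0)
    return h @ a


# Usage on (40, 3)-parity, the setting of Figure 2 (support chosen arbitrarily).
rng = np.random.default_rng(0)
support = [2, 11, 29]
W, b, a = build_parity_net(40, support)
X = rng.choice([-1, 1], size=(1000, 40))
print(W.shape[0])  # number of hidden neurons: 2^3 = 8
print(np.array_equal(forward(X, W, b, a), X[:, support].prod(axis=1)))  # True
```

For every input, exactly one hidden neuron fires, with value 1. The output is therefore the readout weight of the matching pattern, which is the parity of $x_S$. The $n - k$ irrelevant coordinates receive zero weight, so this construction is itself an example of a sparse network.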

Figures (9)

  • Figure 1: An illustration of the structure of a neural network during training on algorithmic tasks. Neural networks often exhibit a memorization phase that corresponds to a dense network, followed by a generalization phase in which a sparse model, largely disjoint from the dense one, takes over.
  • Figure 2: Accuracy (left), average loss (middle), and effective sparsity (right) during training of a fully connected (FC) network on $(40, 3)$ parity. Generalization accuracy jumps abruptly from chance level to perfect accuracy, concurrently with the sparsification of the model. Shaded areas show variability over training-set sampling, model initialization, and SGD stochasticity (5 random seeds).
  • Figure 3: Left: Average norm of different subnetworks during training. Right: Agreement between the predictions of each subnetwork and the full network on the test set. The generalizing subnetwork is the final sparse network, the complementary subnetwork is its complement, and the control subnetwork is a random subnetwork of the same size as the generalizing one.
  • Figure 4: Accuracy curves for addition (left) and division (right). For addition, the dashed line shows the percentage of the dataset that can be solved by commuting a test point's operands and looking the result up in the memorized training set. Generalization accuracy before grokking matches this level, suggesting that the network learns to generalize the commutative property of addition before it generalizes fully.
  • Figure 5: Weight norm of individual neurons during training. Left: Evolution of the dominant neurons at the memorization epoch (the first epoch at which train accuracy exceeds 98%) and at the final epoch (which corresponds to the generalizing subnetwork). Right: Weight norm over time for all neurons. Most of them are driven to 0.
  • ...and 4 more figures
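The subnetwork comparison in Figures 3 and 5 can be sketched for a one-hidden-layer ReLU network with weights `W`, `b`, `a`. The sketch rests on several assumptions that are not taken from the paper: each neuron's norm is the L2 norm of its incoming weights together with its output weight, the sparse subnetwork is the `m` neurons with the largest norm, and agreement is the fraction of inputs on which the sign predictions match. All function names are hypothetical. The usage block hand-builds a toy network in which a few neurons have grown norms, mimicking the targeted norm growth described above; it is not a trained model.

```python
import numpy as np


def neuron_norms(W, a):
    # Assumed norm: L2 norm of each hidden neuron's incoming weights and its output weight.
    return np.sqrt((W ** 2).sum(axis=1) + a ** 2)


def predict(X, W, b, a, mask=None):
    # Sign predictions of a 1-hidden-layer ReLU net; `mask` keeps only selected neurons.
    h = np.maximum(X @ W.T + b, 0.0)
    if mask is not None:
        h = h * mask
    return np.sign(h @ a)


def agreement(X, W, b, a, mask):
    # Fraction of inputs on which the masked subnetwork matches the full network's label.
    return float(np.mean(predict(X, W, b, a, mask) == predict(X, W, b, a)))


def subnetwork_masks(W, a, m, rng):
    # Hypothetical selection rule: the m largest-norm neurons form the sparse subnetwork.
    width = W.shape[0]
    order = np.argsort(neuron_norms(W, a))[::-1]
    sparse = np.zeros(width)
    sparse[order[:m]] = 1.0
    complement = 1.0 - sparse
    # Control: a random subset of neurons of the same size as the sparse subnetwork.
    control = np.zeros(width)
    control[rng.choice(width, size=m, replace=False)] = 1.0
    return sparse, complement, control


# Toy usage: a dense network with small weights plus a few neurons with grown norm.
rng = np.random.default_rng(0)
n, width, m = 40, 1000, 8
X = rng.choice([-1.0, 1.0], size=(2000, n))
W = 0.01 * rng.standard_normal((width, n))
b = np.zeros(width)
a = 0.01 * rng.standard_normal(width)
W[:m] *= 100.0
a[:m] *= 100.0
sparse, complement, control = subnetwork_masks(W, a, m, rng)
for name, mask in [("sparse", sparse), ("complement", complement), ("control", control)]:
    print(name, round(agreement(X, W, b, a, mask), 3))
```

Masking hidden activations is equivalent to zeroing the corresponding output weights, so each subnetwork is evaluated on exactly the same inputs as the full network.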

Theorems & Definitions (6)

  • Proposition 1
  • Proof
  • Proposition 2
  • Proof
  • Proposition 3
  • Proof