Explaining Grokking and Information Bottleneck through Neural Collapse Emergence

Keitaro Sakamoto, Issei Sato

TL;DR

The paper investigates two late-phase phenomena in deep learning, grokking and information bottleneck (IB) dynamics, and proposes neural collapse as a unifying mechanism. It introduces the population within-class variance, relates it to the neural collapse measure NC1, and shows that its contraction governs both delayed generalization and IB compression, with distinct time scales for neural collapse and for minimizing the training loss. The authors provide theoretical bounds linking variance reduction to generalization and to information-theoretic quantities. They validate the theory through extensive experiments on MNIST, Fashion-MNIST, CIFAR-10, and text benchmarks with MLPs, CNNs, ResNets, and transformers. Practically, the work suggests monitoring within-class variance (via the RNC1 score) and using weight decay to modulate the onset of late-phase improvements, offering a cohesive framework for understanding and guiding late-stage training dynamics.
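As a minimal sketch of what monitoring within-class variance could look like, the Python snippet below computes two quantities from penultimate-layer features. The first is the average squared distance of each representation to its class mean. The second is the standard NC1 score $\mathrm{tr}(\Sigma_W \Sigma_B^{\dagger})/K$ from the neural collapse literature. The paper's RNC1 score is not defined in this summary, so both quantities are stand-ins rather than the authors' metric. The function names `within_class_variance` and `nc1` are hypothetical.

```python
import numpy as np


def class_stats(features: np.ndarray, labels: np.ndarray):
    """Return the class ids, per-class mean features, and the global mean feature."""
    classes = np.unique(labels)
    means = np.stack([features[labels == c].mean(axis=0) for c in classes])
    global_mean = features.mean(axis=0)
    return classes, means, global_mean


def within_class_variance(features: np.ndarray, labels: np.ndarray) -> float:
    """Average squared Euclidean distance of each feature vector to its class mean."""
    classes, means, _ = class_stats(features, labels)
    total = 0.0
    for c, mu in zip(classes, means):
        diffs = features[labels == c] - mu
        total += float(np.sum(diffs ** 2))
    return total / len(features)


def nc1(features: np.ndarray, labels: np.ndarray) -> float:
    """Standard NC1 score: tr(Sigma_W @ pinv(Sigma_B)) / K (smaller = more collapsed)."""
    classes, means, global_mean = class_stats(features, labels)
    k = len(classes)
    d = features.shape[1]
    # Within-class covariance, averaged over all samples.
    sigma_w = np.zeros((d, d))
    for c, mu in zip(classes, means):
        diffs = features[labels == c] - mu
        sigma_w += diffs.T @ diffs
    sigma_w /= len(features)
    # Between-class covariance of the class means around the global mean.
    centered = means - global_mean
    sigma_b = centered.T @ centered / k
    return float(np.trace(sigma_w @ np.linalg.pinv(sigma_b)) / k)


if __name__ == "__main__":
    # Toy features: three well-separated class centers plus Gaussian spread.
    rng = np.random.default_rng(0)
    centers = np.array([[3.0, 0.0], [-3.0, 0.0], [0.0, 3.0]])
    labels = np.repeat(np.arange(3), 200)
    noise = rng.standard_normal((600, 2))
    for spread in (1.0, 0.1):
        feats = centers[labels] + spread * noise
        print(spread, within_class_variance(feats, labels), nc1(feats, labels))
```

Shrinking the spread around each class center lowers both numbers, which is the kind of contraction the paper tracks over training. Logging them at regular checkpoints would give curves comparable in spirit to the RNC1 panels.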

Abstract

The training dynamics of deep neural networks often defy expectations, even as these models form the foundation of modern machine learning. Two prominent examples are grokking, where test performance improves abruptly long after the training loss has plateaued, and the information bottleneck principle, where models progressively discard input information irrelevant to the prediction task as training proceeds. However, the mechanisms underlying these phenomena and their relations remain poorly understood. In this work, we present a unified explanation of such late-phase phenomena through the lens of neural collapse, which characterizes the geometry of learned representations. We show that the contraction of population within-class variance is a key factor underlying both grokking and information bottleneck, and relate this measure to the neural collapse measure defined on the training set. By analyzing the dynamics of neural collapse, we show that distinct time scales between fitting the training set and the progression of neural collapse account for the behavior of the late-phase phenomena. Finally, we validate our theoretical findings on multiple datasets and architectures.

Paper Structure

This paper contains 42 sections, 11 theorems, 52 equations, 14 figures.

Key Result

Theorem 3.2

For a fixed feature extractor $g$ and the last layer ${\bm{W}} = ({\bm{w}}_1, \ldots, {\bm{w}}_K)^\top$, we have

Figures (14)

  • Figure 1: Conceptual relationships in the late-phase training discussed in this work.
  • Figure 2: Margins of individual examples at two time steps during grokking. The margin of each example is defined as the signed distance from its representation to the decision boundary determined by the last-layer classifier, calculated as $\left( \langle {\bm{w}}_0 - {\bm{w}}_1, g({\bm{x}}) \rangle + b_0 - b_1 \right) / \| {\bm{w}}_0 - {\bm{w}}_1 \|_2$, where ${\bm{w}}_c$ and $b_c$ denote the weight vector and bias term for class $c$. We trained a 4-layer MLP on the MNIST dataset. These results reveal the link between representation variance and generalization, supporting Theorems 3.2 and 4.1. Similar visualizations for several other class pairs are provided in the appendix.
  • Figure 3: Dynamics of test accuracy, RNC1, and NC2 scores throughout training for different weight decay coefficients $\lambda$. In the test accuracy panel (left), training accuracy is also shown as dashed lines of the same color to visualize grokking behavior. Results are averaged over five seeds with an MLP trained on the MNIST dataset. These results demonstrate the connection between neural collapse and grokking, as well as their time scales, supporting Theorems 3.2, 4.1, and 4.3.
  • Figure 4: Dynamics of redundant information in the IB framework (estimated via mutual information (MI) and nHSIC) and RNC1 scores throughout training for different weight decay coefficients $\lambda$. Results are averaged over five seeds with an MLP trained on the MNIST dataset. These results show the connection between neural collapse and IB dynamics, as well as their time scales, supporting Theorems 3.4, 4.1, and 4.3.
  • Figure 5: Margins of individual examples at two time steps during grokking. The model is a 4-layer MLP trained on the MNIST dataset. Here, we provide the corresponding plots for several class pairs other than the 0-1 pair shown in Figure 2.
  • ...and 9 more figures
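The margin in the Figure 2 caption can be computed directly from the last-layer weights. Below is a minimal NumPy sketch. It assumes `W` stores one weight vector per class as a row, 0-indexed by class label as in the caption's ${\bm{w}}_0, {\bm{w}}_1$, and that `b` holds the matching biases. `pairwise_margins` is a hypothetical helper name.

```python
import numpy as np


def pairwise_margins(feats: np.ndarray, W: np.ndarray, b: np.ndarray,
                     c0: int = 0, c1: int = 1) -> np.ndarray:
    """Signed distance of each representation g(x) to the c0-vs-c1 decision boundary.

    feats: (n, d) penultimate-layer representations g(x)
    W:     (K, d) last-layer weights, row c is the weight vector of class c (assumed layout)
    b:     (K,)   last-layer biases
    Returns (<w_c0 - w_c1, g(x)> + b_c0 - b_c1) / ||w_c0 - w_c1||_2 for every row of feats.
    """
    w_diff = W[c0] - W[c1]
    b_diff = b[c0] - b[c1]
    return (feats @ w_diff + b_diff) / np.linalg.norm(w_diff)


if __name__ == "__main__":
    W = np.array([[1.0, 0.0], [0.0, 1.0]])
    b = np.zeros(2)
    feats = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    print(pairwise_margins(feats, W, b))
```

A positive value means the logit of class $c_0$ exceeds that of class $c_1$, so the representation lies on the class-$c_0$ side of the boundary. Tracking the spread of these values across checkpoints is one way to visualize the variance contraction the figures describe.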

Theorems & Definitions (23)

  • Definition 3.1: Population within-class variance
  • Theorem 3.2: Generalization via population within-class variance
  • Proposition 3.3: Phase 1 of IB dynamics
  • Theorem 3.4: Phase 2 of IB dynamics via population within-class variance
  • Theorem 4.1: Concentration of within-class variance
  • Remark 4.2: Difference between RNC1 and NC1
  • Theorem 4.3: Time scales of neural collapse dynamics
  • Remark 4.4: Summary of Theoretical Results
  • Lemma B.1: Cantelli's inequality
  • Proof: Proof of Theorem 3.2
  • ...and 13 more
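Lemma B.1 is Cantelli's inequality, the one-sided Chebyshev bound. Its standard statement is given below for reference, together with an illustrative consequence for a margin-like random variable $M$ (an assumed setup, not the exact statement of Theorem 3.2). The illustration shows in general terms why, at a fixed mean margin, a smaller variance yields a smaller bound on the probability of landing on the wrong side of the boundary.

```latex
% Cantelli's inequality: X has mean \mu and finite variance \sigma^2; t > 0.
\Pr\left( X - \mu \ge t \right) \le \frac{\sigma^2}{\sigma^2 + t^2}

% Illustration (assumed setup): a margin M with mean m > 0 and variance v.
% Apply the inequality to -M (mean -m, variance v) with t = m:
\Pr\left( M \le 0 \right) = \Pr\left( (-M) - (-m) \ge m \right) \le \frac{v}{v + m^2}
```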