Flatness is Necessary, Neural Collapse is Not: Rethinking Generalization via Grokking

Ting Han, Linara Adilova, Henning Petzka, Jens Kleesiek, Michael Kamp

TL;DR

The paper investigates whether neural collapse (NC) or loss-landscape flatness causally underpins generalization. Using grokking to temporally separate memorization from generalization, it shows that NC can emerge without being necessary for generalization, while relative flatness consistently aligns with the onset of generalization. The authors also show that regularizing models away from flat solutions delays generalization (grokking-like behavior) across diverse architectures and tasks, and that NC can be suppressed without harming generalization. Theoretically, NC implies relative flatness under classical assumptions, which ties the two phenomena together, while representativeness of the learned features remains essential for generalization. Overall, the work positions relative flatness as a more fundamental driver of generalization than NC and proposes grokking as a powerful probe into the geometry of learning.

Abstract

Neural collapse, i.e., the emergence of highly symmetric, class-wise clustered representations, is frequently observed in deep networks and is often assumed to reflect or enable generalization. In parallel, flatness of the loss landscape has been theoretically and empirically linked to generalization. Yet the causal role of either phenomenon remains unclear: are they prerequisites for generalization, or merely by-products of training dynamics? We disentangle these questions using grokking, a training regime in which memorization precedes generalization, which allows us to temporally separate generalization from training dynamics. We find that while both neural collapse and relative flatness emerge near the onset of generalization, only flatness consistently predicts it. Models encouraged to collapse or prevented from collapsing generalize equally well, whereas models regularized away from flat solutions exhibit delayed generalization, resembling grokking, even in architectures and datasets where it does not typically occur. Furthermore, we show theoretically that neural collapse leads to relative flatness under classical assumptions, explaining their empirical co-occurrence. Our results support the view that relative flatness is a potentially necessary and more fundamental property for generalization, and demonstrate how grokking can serve as a powerful probe for isolating its geometric underpinnings.
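The intervention of regularizing a model away from flat solutions and later removing the regularizer can be sketched as a per-epoch objective. This is a minimal sketch under stated assumptions: the function name `total_loss`, the weight `lam`, and the choice to subtract a relative-flatness value (so that larger, i.e. sharper, values are rewarded) are hypothetical and not taken from the paper. The removal ("unplug") epoch mirrors the schedule described in the Figure 4 caption.

```python
def total_loss(task_loss, flatness, epoch, lam, unplug_epoch):
    """Hypothetical per-epoch objective that steers training away from flat solutions.

    task_loss:    e.g. the mean cross-entropy on the current batch.
    flatness:     a relative-flatness value (larger means sharper).
    lam:          regularization strength (assumed hyperparameter).
    unplug_epoch: epoch from which the regularizer is switched off.
    """
    if epoch < unplug_epoch:
        # Minimizing this rewards a larger flatness value, i.e. a sharper solution.
        return task_loss - lam * flatness
    # After unplugging, only the task loss is optimized.
    return task_loss


# Usage with the CIFAR-10 removal epoch reported for Figure 4 (lam is illustrative).
print(total_loss(1.0, 2.0, epoch=10, lam=0.1, unplug_epoch=200))   # regularized
print(total_loss(1.0, 2.0, epoch=250, lam=0.1, unplug_epoch=200))  # plain task loss
```

Because the subtracted term is unbounded, a practical version would likely need to bound or clip it; the sketch only shows the on/off schedule.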

Paper Structure

This paper contains 30 sections, 4 theorems, 36 equations, 20 figures, 1 table.

Key Result

Proposition 5.3

Let $f(x) = \text{softmax}(w \phi(x)+b)$ be a neural network with softmax output, trained with cross-entropy loss, where $w \in \mathbb{R}^{k \times d}$ denotes the final-layer weight matrix classifying into $k$ classes and $\phi(x)$ is the penultimate-layer representation. Assume that the classi… In particular, for sufficiently large $\lambda$, this yields the asymptotic bound: … which decays ex…
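For the final-layer setting of Proposition 5.3, $f(x) = \text{softmax}(w \phi(x)+b)$ with cross-entropy loss, last-layer relative flatness has a closed form. The sketch below assumes the trace-based definition of relative flatness from Petzka et al. (2021), $\kappa(w) = \sum_{s,s'} \langle w_s, w_{s'} \rangle \operatorname{Tr}(H_{s,s'}(w))$, where $w_s$ is the $s$-th row of $w$ and $H_{s,s'}$ is the Hessian block of the empirical loss with respect to rows $s$ and $s'$. Whether Definition 3.2 matches this exactly cannot be confirmed from this summary, and the function names are hypothetical.

```python
import numpy as np


def softmax(z):
    # Numerically stable row-wise softmax.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def relative_flatness(w, b, feats):
    """Trace-based relative flatness of the final layer (sketch).

    w: (k, d) final-layer weights, b: (k,) bias, feats: (n, d) rows phi(x).
    Computes sum_{s,s'} <w_s, w_s'> Tr(H_{s,s'}), where H_{s,s'} is the
    (d, d) block of the Hessian of the mean cross-entropy w.r.t. rows w_s, w_s'.
    """
    p = softmax(feats @ w.T + b)           # (n, k) predicted probabilities
    gram = w @ w.T                          # (k, k) entries <w_s, w_s'>
    sq_norms = np.sum(feats ** 2, axis=1)   # (n,) values ||phi(x)||^2
    # Per sample, Tr(H_{s,s'}) = (diag(p) - p p^T)_{s s'} * ||phi(x)||^2,
    # so the double sum reduces to Tr(gram @ (diag(p) - p p^T)) * ||phi(x)||^2.
    per_sample = p @ np.diag(gram) - np.einsum("ns,st,nt->n", p, gram, p)
    return float(np.mean(sq_norms * per_sample))
```

The Hessian of cross-entropy with respect to the logits, $\operatorname{diag}(p) - pp^\top$, does not depend on the labels, so no targets are needed. Rescaling $w \to \alpha w$ and $\phi \to \phi/\alpha$ leaves both the network output and this value unchanged, which is the reparameterization invariance that motivates a relative measure.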

Figures (20)

  • Figure 1: Neural collapse clustering and relative flatness in grokking. While both correlate with generalization, neural collapse emerges early during memorization, whereas the relative flatness measure drops sharply only once generalization begins, making flatness the better indicator of generalization onset.
  • Figure 2: Results of neural collapse clustering (NCC) regularization on CIFAR-10, showing unregularized and regularized (REG) training dynamics for comparison. Increasing NCC affects neither generalization nor relative flatness, indicating that NC is not necessary for generalization. (a) Training and validation accuracy. (b) Relative flatness during training. (c) NCC value during training.
  • Figure 3: Pairwise cluster angles vs. the optimal simplex angles for 10 classes. Under NCC regularization, the angles drift from the optimal configuration, while in standard training they remain stable. "REG" indicates use of the NCC regularizer.
  • Figure 4: Results of inducing delayed generalization through relative flatness regularization when training ResNet-18 on CIFAR-10 and ViT on ImageNet-100. We display both unregularized and regularized (unplug) training dynamics for comparison. The relative flatness regularizer is removed at epoch 200 for CIFAR-10 and at epoch 150 for ImageNet-100. Delayed generalization occurs after the regularizer is removed, as indicated by a sharp increase in validation accuracy following a drop. (a) Training and validation accuracy on CIFAR-10; (b) the same on ImageNet-100.
  • Figure 5: Results of training and validation losses in the NCC experiments.
  • ...and 15 more figures
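Figure 3 compares pairwise cluster angles with the simplex optimum. Under neural collapse, centred class means approach a simplex equiangular tight frame, in which every pair of the $K$ class means has cosine $-1/(K-1)$; for $K = 10$ this is about $96.38^\circ$. The sketch below computes both quantities, assuming the angles are measured between class means centred by their average (equal to the global feature mean for balanced classes). The paper's exact procedure is not given here, and the function names are hypothetical.

```python
import numpy as np


def pairwise_cluster_angles(feats, labels):
    """Angles (degrees) between centred class means, one value per class pair."""
    classes = np.unique(labels)
    means = np.stack([feats[labels == c].mean(axis=0) for c in classes])
    centred = means - means.mean(axis=0)   # centre by the mean of the class means
    unit = centred / np.linalg.norm(centred, axis=1, keepdims=True)
    cos = np.clip(unit @ unit.T, -1.0, 1.0)  # guard arccos against rounding
    i, j = np.triu_indices(len(classes), k=1)
    return np.degrees(np.arccos(cos[i, j]))


def simplex_etf_angle(num_classes):
    # In a simplex ETF every pair of class means has cosine -1 / (K - 1).
    return float(np.degrees(np.arccos(-1.0 / (num_classes - 1))))
```

Plotting the returned angles against `simplex_etf_angle(10)` over training gives the kind of comparison Figure 3 describes.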

Theorems & Definitions (10)

  • Definition 3.1: NCC Measure
  • Definition 3.2: Relative Flatness
  • Remark 5.2
  • Proposition 5.3
  • Lemma A.1
  • Proof
  • Lemma A.2
  • Proof
  • Proposition A.4
  • Proof
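Definition 3.1 (NCC Measure) is not reproduced in this summary. As a point of reference only, the sketch below computes the widely used NC1 within-class variability metric of Papyan et al. (2020), $\operatorname{Tr}(\Sigma_W \Sigma_B^{\dagger})/K$, which shrinks toward zero as features collapse onto their class means. It is an assumed stand-in that may differ from the paper's NCC measure, and the function name is hypothetical.

```python
import numpy as np


def nc1_variability(feats, labels):
    """Tr(Sigma_W pinv(Sigma_B)) / K; smaller values mean tighter class clusters."""
    classes = np.unique(labels)
    global_mean = feats.mean(axis=0)
    d = feats.shape[1]
    sigma_w = np.zeros((d, d))
    sigma_b = np.zeros((d, d))
    for c in classes:
        fc = feats[labels == c]
        mu_c = fc.mean(axis=0)
        diff = fc - mu_c
        sigma_w += diff.T @ diff               # within-class scatter
        dm = (mu_c - global_mean)[:, None]
        sigma_b += dm @ dm.T                    # between-class scatter
    sigma_w /= len(feats)       # within-class covariance, averaged over samples
    sigma_b /= len(classes)     # between-class covariance of the class means
    # Pseudo-inverse because Sigma_B has rank at most K - 1.
    return float(np.trace(sigma_w @ np.linalg.pinv(sigma_b)) / len(classes))
```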