Towards Understanding Neural Collapse: The Effects of Batch Normalization and Weight Decay

Leyan Pan, Xinyuan Cao

TL;DR

The paper analyzes Neural Collapse (NC), a geometric arrangement in which last-layer features within a class collapse to their class mean and the class means form a simplex equiangular tight frame (ETF), through the lens of last-layer Batch Normalization (BN) and Weight Decay (WD). Working in a layer-peeled model with near-optimal cross-entropy loss, it derives explicit bounds on the intra-class and inter-class cosine similarities that quantify proximity to NC, and these bounds show that BN and sufficiently large WD strengthen the NC guarantees. The theory is complemented by experiments on synthetic and real datasets (MNIST, CIFAR-10/100, ImageNet32), which show that BN combined with higher WD values yields stronger NC proximity, especially as training loss decreases and last-layer feature norms shrink. Overall, the work provides an optimization-agnostic perspective on how BN and WD shape feature geometry, with implications for understanding generalization and the role of normalization in deep networks.

Abstract

Neural Collapse (NC) is a geometric structure recently observed at the terminal phase of training deep neural networks, which states that last-layer feature vectors for the same class would "collapse" to a single point, while features of different classes become equally separated. We demonstrate that batch normalization (BN) and weight decay (WD) critically influence the emergence of NC. In the near-optimal loss regime, we establish an asymptotic lower bound on the emergence of NC that depends only on the WD value, training loss, and the presence of last-layer BN. Our experiments substantiate theoretical insights by showing that models demonstrate a stronger presence of NC with BN, appropriate WD values, lower loss, and lower last-layer feature norm. Our findings offer a novel perspective in studying the role of BN and WD in shaping neural network features.
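
As a brief worked illustration of what "equally separated" means (a standard fact about the simplex ETF, not a result quoted from this paper), suppose the $C$ globally centered class means, normalized to unit vectors $\mathbf{u}_1, \dots, \mathbf{u}_C$, share a common pairwise cosine similarity $\rho$. The symbols $\mathbf{u}_c$ and $\rho$ are introduced here for illustration only. Then

$$0 \le \Big\| \sum_{c=1}^{C} \mathbf{u}_c \Big\|^2 = C + C(C-1)\,\rho \quad\Longrightarrow\quad \rho \ge -\frac{1}{C-1}.$$

The simplex ETF attains this bound with equality, so under exact NC the inter-class cosine similarity equals $-1/(C-1)$ (for example, $-1/9$ for ten classes), while the intra-class cosine similarity equals $1$.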

Paper Structure (31 sections, 18 theorems, 135 equations, 9 figures, 1 table)

Key Result

Theorem 1.1

For the layer-peeled classification model of $C$ classes with weight decay parameter $\lambda$ and cross-entropy training loss within $\epsilon$ of the optimal loss, the following holds for most classes/pairs of classes:

Figures (9)

  • Figure 1: Visualization of $\mathcal{NC}$ (doi:10.1073/pnas.2015509117). We use an example of three classes and denote the last-layer features $\mathbf{h}_{c,i}$, mean class features $\tilde{\mathbf{h}}_{c}$, and last-layer class weight vectors $\mathbf{w}_{c}$. Circles denote individual last-layer features, while compound and filled arrows denote class weight and mean feature vectors, respectively. As training progresses, the last-layer features of each class collapse to their corresponding class means (NC1), different class means converge to the vertices of the simplex ETF (NC2), and the class weight vector of the last-layer linear classifier approaches the corresponding class means (NC3).
  • Figure 2: $\mathcal{NC}$ increases with WD under BN: Minimum intra-class and maximum inter-class Cosine Similarity for 4-layer and 6-layer MLP under Different WD and BN on the synthetic dataset generated using a randomly initialized 3-layer MLP. Higher values of intra-class and lower values of inter-class cosine similarity imply a higher degree of Neural Collapse. The green and yellow lines are cosine similarity measures for the model with BN, which demonstrates stronger $\mathcal{NC}$ along with higher WD values. Standard deviation over 5 experiments.
  • Figure 3: $\mathcal{NC}$ closely represents loss value under BN: Relationship between $\mathcal{NC}$ and training loss during the training process. The purple dashed line is the training loss, plotted on a log scale with axis labels on the right. The models with Batch Normalization (plots 1 and 3) show a stronger correlation between loss value and $\mathcal{NC}$ during training.
  • Figure 4: $\mathcal{NC}$ correlates with feature norm: Min intra-class and max inter-class Cosine Similarity for synthetic dataset and MLP models with BN under different $|\boldsymbol{\gamma}|$ values. Higher intra-class and lower inter-class cosine similarity indicate a higher degree of $\mathcal{NC}$. Note that the intra-class and inter-class cosine similarity are split into two plots to display more detailed changes. Except for the 6-layer MLP trained on the conic hull dataset, all settings demonstrate a negative correlation between proximity to $\mathcal{NC}$ and the last-layer feature norm value as constrained by $|\boldsymbol{\gamma}|$. Standard Deviation over 3 experiments.
  • Figure 5: Intra-class and Inter-class Cosine Similarity for VGG11 and VGG19 on the CIFAR10 and CIFAR100 datasets under different WD and BN combinations. Higher intra-class and lower inter-class cosine similarity indicate a higher degree of $\mathcal{NC}$. Both the average measures over all classes and the worst class are presented. The green and red lines are cosine similarity measures for the model with BN. In most cases, the models with BN demonstrate observably better $\mathcal{NC}$ than non-BN models, and the $\mathcal{NC}$ is more evident in models trained with larger WD values.
  • ...and 4 more figures
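
The captions above report minimum intra-class and maximum inter-class cosine similarity as measures of proximity to $\mathcal{NC}$. A minimal NumPy sketch of how such quantities could be computed is given below. It assumes that features are centered by the global mean before cosines are taken, that each class's intra-class value averages over pairs of distinct samples, and that every class has at least two samples with nonzero centered features. These choices, the function name `nc_cosine_metrics`, and its parameters are illustrative assumptions, not the authors' implementation.

```python
import numpy as np


def nc_cosine_metrics(features, labels, center=True):
    """Return (min intra-class, max inter-class) average cosine similarity.

    features: (N, d) array of last-layer features.
    labels:   (N,) array of class labels.
    center:   subtract the global feature mean first (an assumption of this sketch).
    Requires at least two classes and at least two samples per class.
    """
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    if center:
        features = features - features.mean(axis=0, keepdims=True)

    # Normalize every feature to unit length so dot products are cosines.
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    unit = features / np.clip(norms, 1e-12, None)

    classes = np.unique(labels)
    # Per-class sum of unit vectors and per-class sample counts.
    sums = np.stack([unit[labels == c].sum(axis=0) for c in classes])
    counts = np.array([(labels == c).sum() for c in classes], dtype=float)

    # ||sum_i u_i||^2 sums cosines over all ordered pairs (i, j), including i == j.
    # Removing the n_c diagonal terms (each equal to 1) leaves the pairs with i != j.
    intra = ((sums ** 2).sum(axis=1) - counts) / (counts * (counts - 1))

    # Average cosine between samples of class c and class c' is
    # (s_c . s_c') / (n_c * n_c').
    inter = (sums @ sums.T) / np.outer(counts, counts)
    off_diag = ~np.eye(len(classes), dtype=bool)

    return float(intra.min()), float(inter[off_diag].max())


# Usage: features collapsed exactly onto a three-class simplex ETF.
C, n = 3, 4
etf_means = np.eye(C) - np.ones((C, C)) / C
demo_features = np.repeat(etf_means, n, axis=0)
demo_labels = np.repeat(np.arange(C), n)
print(nc_cosine_metrics(demo_features, demo_labels))
```

For the exactly collapsed three-class example, the two returned values should be approximately 1 and -0.5, matching the ideal values $1$ and $-1/(C-1)$. Working with per-class sums of unit vectors avoids building the full $N \times N$ cosine matrix.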

Theorems & Definitions (28)

  • Theorem 1.1: Informal version of Theorem 2.2
  • Theorem 2.1: $\mathcal{NC}$ proximity guarantee with bounded norms
  • Lemma 2.1: Subset mean close to global mean by Jensen's inequality on strongly convex functions
  • Theorem 2.2: Formal version of Theorem 1.1
  • Lemma C.1: Restatement of Lemma 2.1
  • proof
  • Lemma C.2: Theorem 4 from Merentes2010
  • Corollary C.1
  • Lemma C.3
  • proof
  • ...and 18 more