
An Unconstrained Layer-Peeled Perspective on Neural Collapse

Wenlong Ji, Yiping Lu, Yiliang Zhang, Zhun Deng, Weijie J. Su

TL;DR

This work introduces the unconstrained layer-peeled model (ULPM) to study neural collapse in the last layer of classifiers. It shows that gradient flow on ${\boldsymbol{W}}$ and ${\boldsymbol{H}}$ converges in direction to a Karush-Kuhn-Tucker (KKT) point of a minimum-norm separation problem whose global optimum exhibits neural collapse (NC1–NC4). It also proves that the ULPM with cross-entropy loss has a benign global landscape in which every stationary point other than the global minimizers is a strict saddle, so that, despite nonconvexity, gradient-based training is expected to reach neural-collapse solutions. Empirically, the authors observe neural collapse when training without weight decay on datasets such as MNIST and CIFAR-10 across several architectures (e.g., ResNet18 and VGG13), indicating that the implicit regularization of gradient descent and the loss function suffices to drive the phenomenon in practice. The results clarify how the implicit bias of gradient dynamics produces neural collapse, a geometry associated with generalization and robustness, and suggest that explicit feature-norm constraints are not essential for neural-collapse behavior.
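For reference, the two optimization problems named in the summary can be written as follows. This is a sketch in notation we assume ($K$ classes, $n$ samples per class, classifier rows ${\boldsymbol{w}}_k$, free features ${\boldsymbol{h}}_{k,i}$). The paper's (P1) and its exact constraint normalization may differ.

```latex
% ULPM (assumed form): cross-entropy over the classifier W and free
% last-layer features H, with no norm constraints and no regularization.
\min_{{\boldsymbol{W}},{\boldsymbol{H}}}\;
\mathcal{L}({\boldsymbol{W}},{\boldsymbol{H}})
= \sum_{k=1}^{K}\sum_{i=1}^{n}
-\log\frac{\exp({\boldsymbol{w}}_k^\top{\boldsymbol{h}}_{k,i})}
{\sum_{j=1}^{K}\exp({\boldsymbol{w}}_j^\top{\boldsymbol{h}}_{k,i})}

% Minimum-norm separation problem (assumed form). The logits are
% 2-homogeneous in (W, H), and the direction of gradient flow is attracted
% to KKT points of this problem.
\min_{{\boldsymbol{W}},{\boldsymbol{H}}}\;
\tfrac12\|{\boldsymbol{W}}\|_F^2+\tfrac12\|{\boldsymbol{H}}\|_F^2
\quad\text{s.t.}\quad
({\boldsymbol{w}}_k-{\boldsymbol{w}}_j)^\top{\boldsymbol{h}}_{k,i}\ge 1,
\qquad \forall\, k\neq j,\; i\in[n]
```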

Abstract

Neural collapse is a highly symmetric geometric pattern of neural networks that emerges during the terminal phase of training, with profound implications for the generalization performance and robustness of the trained networks. To understand how the last-layer features and classifiers come to exhibit this recently discovered implicit bias, we introduce a surrogate model called the unconstrained layer-peeled model (ULPM). We prove that gradient flow on this model converges to critical points of a minimum-norm separation problem whose global minimizer exhibits neural collapse. Moreover, we show that the ULPM with the cross-entropy loss has a benign global landscape. This allows us to prove that all critical points are strict saddle points, except the global minimizers, which exhibit the neural collapse phenomenon. Empirically, we show that our results also hold when training neural networks on real-world tasks without explicit regularization or weight decay.
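Below is a minimal sketch of unregularized gradient descent, a discretization of gradient flow, on a ULPM-style problem. It makes several assumptions that do not come from the paper. The features ${\boldsymbol{H}}$ are free parameters. There is no bias. There are $K$ balanced classes with $n$ samples each, and the cross-entropy loss is summed over samples. All function names and hyperparameters are hypothetical. The script reports the normalized margin $q_{\min}/(\|{\boldsymbol{W}}\|_F^2+\|{\boldsymbol{H}}\|_F^2)$. By Cauchy–Schwarz, this quantity can never exceed $1/(2(K-1)\sqrt{n})$, and under gradient descent it is expected to grow only slowly.

```python
import numpy as np


def ulpm_loss_and_grads(W, H, labels):
    """Summed cross-entropy of an (assumed) ULPM and its gradients.

    W: (K, d) classifier rows w_k; H: (N, d) free last-layer features;
    labels: (N,) integer class of each feature. No bias, no weight decay.
    """
    rows = np.arange(H.shape[0])
    z = H @ W.T                                    # logits, shape (N, K)
    z = z - z.max(axis=1, keepdims=True)           # stabilize log-sum-exp
    log_norm = np.log(np.exp(z).sum(axis=1))
    loss = (log_norm - z[rows, labels]).sum()      # -log softmax of true class
    G = np.exp(z - log_norm[:, None])              # softmax probabilities
    G[rows, labels] -= 1.0                         # dL/dlogits = softmax - one-hot
    return loss, G.T @ H, G @ W                    # loss, dL/dW, dL/dH


def min_margin(W, H, labels):
    """q_min: smallest (w_y - w_j)^T h over all samples and wrong classes j."""
    rows = np.arange(H.shape[0])
    logits = H @ W.T
    true = logits[rows, labels]
    wrong = logits.copy()
    wrong[rows, labels] = -np.inf                  # mask out the true class
    return float((true - wrong.max(axis=1)).min())


def run_ulpm(K=4, n=10, d=8, steps=20000, lr=0.05, seed=0):
    """Hypothetical driver: plain gradient descent from a small random init."""
    rng = np.random.default_rng(seed)
    labels = np.repeat(np.arange(K), n)            # K balanced classes
    W = 0.1 * rng.standard_normal((K, d))
    H = 0.1 * rng.standard_normal((K * n, d))
    for _ in range(steps):
        loss, gW, gH = ulpm_loss_and_grads(W, H, labels)
        W, H = W - lr * gW, H - lr * gH            # simultaneous update of W and H
    rho2 = (W ** 2).sum() + (H ** 2).sum()
    normalized = min_margin(W, H, labels) / rho2
    bound = 1.0 / (2 * (K - 1) * np.sqrt(n))      # holds for every (W, H)
    return float(loss), normalized, bound


if __name__ == "__main__":
    print(run_ulpm())
```

Because the margin grows only logarithmically in time, longer runs are needed before the normalized margin approaches the bound. This is consistent with the $\log(\log(t))$ axes used in the figures.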

Paper Structure

This paper contains 23 sections, 11 theorems, 121 equations, 8 figures, 1 table.

Key Result

Theorem 3.1

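For the ULPM model (P1), the margin of the entire dataset, $q_{\min}({\boldsymbol{W}},{\boldsymbol{H}})=\min_{k,i}\min_{j\neq k}({\boldsymbol{w}}_k-{\boldsymbol{w}}_j)^\top{\boldsymbol{h}}_{k,i}$, always satisfies $q_{\min}({\boldsymbol{W}},{\boldsymbol{H}})\le\frac{\|{\boldsymbol{W}}\|_F^2+\|{\boldsymbol{H}}\|_F^2}{2(K-1)\sqrt{n}}$, where $K$ is the number of classes and $n$ the number of training samples per class. Equality holds if and only if $({\boldsymbol{W}},{\boldsymbol{H}})$ satisfies the neural collapse conditions with $\|{\boldsymbol{W}}\|_F=\|{\boldsymbol{H}}\|_F$.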
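The margin bound in Theorem 3.1 can be probed numerically. This sketch makes the following assumptions:

- The margin is the unnormalized multiclass margin $q_{\min}=\min_{k,i}\min_{j\neq k}({\boldsymbol{w}}_k-{\boldsymbol{w}}_j)^\top{\boldsymbol{h}}_{k,i}$.
- There are $K$ balanced classes of $n$ samples each.
- The bound takes the form $(\|{\boldsymbol{W}}\|_F^2+\|{\boldsymbol{H}}\|_F^2)/(2(K-1)\sqrt{n})$.

That bound form is what Cauchy–Schwarz followed by AM–GM gives for this margin, but the exact notation is our assumption. A neural-collapse configuration (features equal to simplex-ETF class means, self-dual with ${\boldsymbol{W}}$, and $\|{\boldsymbol{W}}\|_F=\|{\boldsymbol{H}}\|_F$) should meet the bound, and random configurations should fall below it. All helper names are hypothetical.

```python
import numpy as np


def simplex_etf(K, d, rng):
    """Return K unit vectors in R^d (d >= K) with pairwise cosine -1/(K-1)."""
    P, _ = np.linalg.qr(rng.standard_normal((d, K)))   # (d, K), orthonormal columns
    M = np.sqrt(K / (K - 1)) * P @ (np.eye(K) - np.ones((K, K)) / K)
    return M.T                                         # rows are the ETF vertices


def q_min(W, H, labels):
    """Unnormalized margin: min over samples and wrong classes j of (w_y - w_j)^T h."""
    rows = np.arange(len(labels))
    logits = H @ W.T
    true = logits[rows, labels]
    wrong = logits.copy()
    wrong[rows, labels] = -np.inf                      # mask out the true class
    return float((true - wrong.max(axis=1)).min())


def margin_bound(W, H, K, n):
    """Assumed form of the Theorem 3.1 bound (balanced classes, n samples each)."""
    return float(((W ** 2).sum() + (H ** 2).sum()) / (2 * (K - 1) * np.sqrt(n)))


def neural_collapse_config(K, n, d, rng, scale=1.0):
    """Features collapse onto ETF class means, self-dual with W, and
    rescaled so that ||W||_F == ||H||_F (the equality case)."""
    E = simplex_etf(K, d, rng)
    labels = np.repeat(np.arange(K), n)
    W = scale * np.sqrt(n) * E                         # ||W||_F^2 = K n scale^2
    H = scale * E[labels]                              # ||H||_F^2 = K n scale^2
    return W, H, labels


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    K, n, d = 5, 3, 6
    W, H, labels = neural_collapse_config(K, n, d, rng)
    print("NC config:", q_min(W, H, labels), margin_bound(W, H, K, n))  # agree up to rounding
    W_r, H_r = rng.standard_normal((K, d)), rng.standard_normal((K * n, d))
    print("random:   ", q_min(W_r, H_r, labels), margin_bound(W_r, H_r, K, n))
```

In the collapsed configuration, both sides equal $\sqrt{n}\,K/(K-1)$ when `scale=1`. For random draws, the margin is typically far below the bound and often negative.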

Figures (8)

  • Figure 1: Experiments on real datasets without weight decay. We trained a ResNet18 on the MNIST and CIFAR10 datasets. The $x$-axis uses a $\log(\log(t))$ scale and the $y$-axis a $\log$ scale.
  • Figure 2: Training dynamics in ULPM. The $x$-axis uses a $\log(\log(t))$ scale and the $y$-axis a $\log$ scale. (a) Variation of the norms of the centered class-mean features (blue) and variation of the classifiers' norms (red). The logarithm of both terms decreases at a rate of $O(1/\log(t))$. (b) Within-class variation of the last-layer features. Its logarithm converges at approximately the rate $O(1/\log(t))$. (c) Cosines between pairs of last-layer features (blue) and between pairs of classifiers (red). The logarithms of both terms converge at approximately the rate $O(1/\log(t))$. (d) Distance between the normalized centered classifier and the normalized last-layer features. Its logarithm converges at approximately the rate $O(1/\log(t))$, approaching self-duality.
  • Figure 3: Experiments on real datasets without weight decay. We trained a VGG13 on the CIFAR10 dataset. The $x$-axis uses a $\log(\log(t))$ scale and the $y$-axis a $\log$ scale.
  • Figure 4: Experiments on real datasets without weight decay. We trained a VGG13 on the MNIST dataset. The $x$-axis uses a $\log(\log(t))$ scale and the $y$-axis a $\log$ scale.
  • Figure 5: Experiments on real datasets without weight decay. We trained a VGG18 on the KMNIST dataset. The $x$-axis uses a $\log(\log(t))$ scale and the $y$-axis a $\log$ scale.
  • ...and 3 more figures
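The quantities tracked in the figures can be computed directly from $({\boldsymbol{W}},{\boldsymbol{H}})$. This sketch assumes the usual neural-collapse diagnostics from the literature (within-class scatter against between-class scatter, norm variation, deviation of pairwise cosines from $-1/(K-1)$, and distance to self-duality). The paper's exact normalizations may differ. The function name, the dictionary keys, and the mapping to figure panels are our own hypothetical choices.

```python
import numpy as np


def nc_metrics(W, H, labels):
    """Neural-collapse diagnostics for classifier W (K, d) and features H (N, d).

    Every returned value is 0 at an exact neural-collapse configuration.
    """
    K, d = W.shape
    class_means = np.stack([H[labels == k].mean(axis=0) for k in range(K)])
    M = class_means - H.mean(axis=0)                   # centered class means
    # Within-class variation: trace(Sw Sb^+) / K.
    Sw = np.zeros((d, d))
    for k in range(K):
        D = H[labels == k] - class_means[k]
        Sw += D.T @ D
    Sw /= H.shape[0]
    Sb = M.T @ M / K
    within = float(np.trace(Sw @ np.linalg.pinv(Sb)) / K)

    def norm_variation(X):                             # std / mean of row norms
        r = np.linalg.norm(X, axis=1)
        return float(r.std() / r.mean())

    def cosine_gap(X):                                 # mean |cos + 1/(K-1)| off the diagonal
        Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
        off = ~np.eye(K, dtype=bool)
        return float(np.abs((Xn @ Xn.T)[off] + 1.0 / (K - 1)).mean())

    Wc = W - W.mean(axis=0)                            # centered classifier
    gap = Wc / np.linalg.norm(Wc) - M / np.linalg.norm(M)
    return {
        "within_class": within,                        # cf. Fig. 2(b)
        "norm_var_features": norm_variation(M),        # cf. Fig. 2(a), blue
        "norm_var_classifier": norm_variation(W),      # cf. Fig. 2(a), red
        "cos_gap_features": cosine_gap(M),             # cf. Fig. 2(c), blue
        "cos_gap_classifier": cosine_gap(W),           # cf. Fig. 2(c), red
        "self_duality": float(np.linalg.norm(gap) ** 2),  # cf. Fig. 2(d)
    }
```

Logging these values along a training run, such as the gradient-descent sketch on the ULPM, and plotting them against $\log(\log(t))$ would give curves of the kind described for Figure 2.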

Theorems & Definitions (32)

  • Definition 2.1: Simplex ETF
  • Theorem 3.1: Neural collapse as max-margin solution
  • Theorem 3.2
  • Remark 3.1
  • Remark 3.2
  • Corollary 3.1
  • Example 3.1: A Motivating Example
  • Theorem 3.3
  • Remark 3.3
  • Definition 3.1: tangent space
  • ...and 22 more
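Definition 2.1 appears above by name only. The standard simplex equiangular tight frame (ETF) from the original neural-collapse work of Papyan, Han, and Donoho (2020) reads as follows. We assume the paper's definition agrees with it up to an overall scaling.

```latex
% Standard simplex ETF with K vertices embedded in R^p (p >= K):
{\boldsymbol{M}}
= \sqrt{\tfrac{K}{K-1}}\;{\boldsymbol{P}}
\Bigl({\boldsymbol{I}}_K-\tfrac{1}{K}\boldsymbol{1}_K\boldsymbol{1}_K^\top\Bigr),
\qquad {\boldsymbol{P}}\in\mathbb{R}^{p\times K},\;\;
{\boldsymbol{P}}^\top{\boldsymbol{P}}={\boldsymbol{I}}_K .

% Its columns m_k have unit norm, equal pairwise cosines, and sum to zero:
\|{\boldsymbol{m}}_k\|_2=1,\qquad
{\boldsymbol{m}}_k^\top{\boldsymbol{m}}_j=-\tfrac{1}{K-1}\;\;(k\neq j),\qquad
\sum_{k=1}^{K}{\boldsymbol{m}}_k=\boldsymbol{0}.
```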