Neural Collapse Under MSE Loss: Proximity to and Dynamics on the Central Path

X. Y. Han, Vardan Papyan, David L. Donoho

TL;DR

This paper shows that Neural Collapse (NC) is not limited to cross-entropy loss by establishing NC phenomena under mean squared error (MSE) loss. It introduces a decomposed MSE loss, identifies a central path where the last-layer classifier acts as the least-squares classifier, and proves invariance properties with renormalized features. By analyzing continually renormalized gradient flow on the renormalized feature manifold, it derives exact dynamics for the Signal-to-Noise Ratio (SNR) matrix that drive NC, including a closed-form relation $c_1 \log(\omega_j) + c_2 \omega_j^2 + c_3 \omega_j^4 = a_j + t$ for each nonzero singular value $\omega_j$. The results show that as training progresses, the SNR singular values diverge and equalize, causing the class-means and classifiers to align with a Simplex ETF and leading to NC (NC1–NC4), thereby providing a rigorous, tractable theory for NC under MSE and offering insights into training dynamics beyond CE loss.
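The implicit relation for the SNR singular values can be explored numerically. The sketch below is only an illustration: the constants c1, c2, c3 and the offsets a_j are placeholder values (this summary does not give them), and they are assumed positive so that the left-hand side is strictly increasing in the singular value and can be inverted by bisection. The function name `omega_at` is hypothetical.

```python
import math

def omega_at(t, a, c1=1.0, c2=1.0, c3=1.0):
    """Solve c1*log(w) + c2*w**2 + c3*w**4 = a + t for w > 0 by bisection.

    Assumes c1, c2, c3 > 0 (placeholder values), so the left-hand side is
    strictly increasing in w and the root is unique. Intended for moderate
    values of a + t.
    """
    target = a + t

    def f(w):
        return c1 * math.log(w) + c2 * w**2 + c3 * w**4 - target

    lo, hi = 1e-12, 1.0
    while f(lo) > 0:      # shrink until the root is bracketed from below
        lo *= 0.5
    while f(hi) < 0:      # grow until the root is bracketed from above
        hi *= 2.0
    for _ in range(200):  # plain bisection on the bracket [lo, hi]
        mid = 0.5 * (lo + hi)
        if f(mid) < 0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

# Three singular values with different (hypothetical) offsets a_j.
offsets = (-2.0, 0.0, 3.0)
for t in (0.0, 10.0, 1e3, 1e6):
    ws = [omega_at(t, a) for a in offsets]
    print(f"t={t:g}  omegas={ws}  max/min={max(ws) / min(ws)}")
```

Under these assumptions the quartic term dominates for large t, so every singular value grows without bound while the ratio between the largest and smallest approaches one, which is the "diverge and equalize" behavior described above.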

Abstract

The recently discovered Neural Collapse (NC) phenomenon occurs pervasively in today's deep net training paradigm of driving cross-entropy (CE) loss towards zero. During NC, last-layer features collapse to their class-means, both classifiers and class-means collapse to the same Simplex Equiangular Tight Frame, and classifier behavior collapses to the nearest-class-mean decision rule. Recent works demonstrated that deep nets trained with mean squared error (MSE) loss perform comparably to those trained with CE. As a preliminary, we empirically establish that NC emerges in such MSE-trained deep nets as well through experiments on three canonical networks and five benchmark datasets. We provide, in a Google Colab notebook, PyTorch code for reproducing MSE-NC and CE-NC: at https://colab.research.google.com/github/neuralcollapse/neuralcollapse/blob/main/neuralcollapse.ipynb. The analytically-tractable MSE loss offers more mathematical opportunities than the hard-to-analyze CE loss, inspiring us to leverage MSE loss towards the theoretical investigation of NC. We develop three main contributions: (I) We show a new decomposition of the MSE loss into (A) terms directly interpretable through the lens of NC and which assume the last-layer classifier is exactly the least-squares classifier; and (B) a term capturing the deviation from this least-squares classifier. (II) We exhibit experiments on canonical datasets and networks demonstrating that term-(B) is negligible during training. This motivates us to introduce a new theoretical construct: the central path, where the linear classifier stays MSE-optimal for feature activations throughout the dynamics. (III) By studying renormalized gradient flow along the central path, we derive exact dynamics that predict NC.
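The Simplex Equiangular Tight Frame mentioned in the abstract can be made concrete with a small NumPy sketch. It builds a standard Simplex ETF for C classes and measures how far a set of class-means is from one, using two simple diagnostics: the spread of the norms and the deviation of pairwise cosines from -1/(C-1). The function names and the choice of diagnostics are illustrative assumptions, not the paper's exact metrics.

```python
import numpy as np

def simplex_etf(num_classes):
    """Return a (C, C) matrix whose columns are the vertices of a standard Simplex ETF."""
    C = num_classes
    return np.sqrt(C / (C - 1)) * (np.eye(C) - np.ones((C, C)) / C)

def etf_deviation(class_means):
    """Measure distance from a Simplex ETF.

    class_means: (d, C) array with one (globally centered) class-mean per column.
    Returns (coefficient of variation of the column norms,
             max |cos(mean_i, mean_j) + 1/(C-1)| over i != j).
    Both values are zero for an exact Simplex ETF.
    """
    C = class_means.shape[1]
    norms = np.linalg.norm(class_means, axis=0)
    unit = class_means / norms
    cosines = unit.T @ unit
    off_diag = cosines[~np.eye(C, dtype=bool)]
    return norms.std() / norms.mean(), np.abs(off_diag + 1.0 / (C - 1)).max()

# An exact ETF versus a noisy copy of it (noise level is arbitrary).
rng = np.random.default_rng(0)
M = simplex_etf(4)
print("exact ETF:", etf_deviation(M))
print("noisy ETF:", etf_deviation(M + 0.1 * rng.normal(size=M.shape)))
```

Equal norms correspond to equinormness and the common cosine of -1/(C-1) corresponds to maximal equiangularity, the two properties the later figures track during training.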

Paper Structure

This paper contains 40 sections, 21 theorems, 118 equations, 14 figures, 1 table.

Key Result

Proposition 1

For fixed extended features $\widetilde{\boldsymbol{H}}$, the optimal classifier minimizing the MSE loss $\mathcal{L}(\widetilde{\boldsymbol{W}},\widetilde{\boldsymbol{H}})$ is the least-squares classifier $\widetilde{\boldsymbol{W}}_{\!\text{LS}}$, given by a closed-form expression in which $\boldsymbol{I}$ is the identity matrix. Note that $\widetilde{\boldsymbol{W}}_{\!\text{LS}}$ depends on $\widetilde{\boldsymbol{H}}$ only.
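A minimal sketch of such a least-squares classifier follows, under stated assumptions that are not confirmed by this summary: the loss is taken to be (1/2N)·||WH − Y||_F² + (λ/2)·||W||_F², "extended" features are taken to mean features with an appended row of ones (absorbing the bias), and Y holds one-hot targets. Under that assumed loss the minimizer is the ridge-regression solution W = (Y Hᵀ/N)(H Hᵀ/N + λI)⁻¹; the paper's own expression may be written in different terms.

```python
import numpy as np

def mse_loss(W, H, Y, lam):
    """Assumed loss: (1/2N)||W H - Y||_F^2 + (lam/2)||W||_F^2."""
    N = H.shape[1]
    return 0.5 / N * np.sum((W @ H - Y) ** 2) + 0.5 * lam * np.sum(W ** 2)

def least_squares_classifier(H, Y, lam):
    """Closed-form minimizer of the assumed loss for fixed features H.

    H: (d, N) extended features, Y: (C, N) one-hot targets, lam > 0.
    """
    d, N = H.shape
    A = H @ H.T / N + lam * np.eye(d)   # symmetric positive definite for lam > 0
    B = Y @ H.T / N                     # (C, d)
    # W A = B  is equivalent to  A W^T = B^T because A is symmetric.
    return np.linalg.solve(A, B.T).T

# Toy data: 3 classes, 20 samples each, 5 features plus a bias row of ones.
rng = np.random.default_rng(0)
labels = np.repeat(np.arange(3), 20)
H = np.vstack([rng.normal(size=(5, labels.size)), np.ones((1, labels.size))])
Y = np.eye(3)[:, labels]
lam = 1e-2

W_ls = least_squares_classifier(H, Y, lam)
grad = (W_ls @ H - Y) @ H.T / labels.size + lam * W_ls
print("max |gradient| at W_ls:", np.abs(grad).max())
print("loss at W_ls:", mse_loss(W_ls, H, Y, lam))
```

The vanishing gradient confirms optimality for the assumed loss. Once the targets and λ are fixed, the solution is a function of the features alone, consistent with the remark that the least-squares classifier depends on the extended features only.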

Figures (14)

  • Figure 1: Portrait of Neural Collapse. Top figure depicts the last-layer features, class-means, and classifiers with which NC is defined, as well as the Simplex ETF to which they all converge with training. Bottom figure shows the deviations of features from their corresponding class-means. Reproduced and modified from Figure 1 of papyan2020prevalence.
  • Figure 2: Decomposition of MSE loss: Each array column shows a benchmark image classification dataset while each row shows a canonical deep net architecture trained with MSE loss. The red vertical line indicates the epoch at which zero training error was achieved. In each array cell, we plot terms of the MSE loss decomposition $\mathcal{L}(\widetilde{\boldsymbol{W}},\widetilde{\boldsymbol{H}}) = \mathcal{L}_{\text{NC1}}(\widetilde{\boldsymbol{H}}) + \mathcal{L}_{\text{NC2/3}}(\widetilde{\boldsymbol{H}}) + \mathcal{L}_{\text{LS}}^\perp(\widetilde{\boldsymbol{W}},\widetilde{\boldsymbol{H}})$ from Section \ref{sec:decomp}. Starting from an early epoch in training, $\mathcal{L}_{\text{LS}}^\perp(\widetilde{\boldsymbol{W}},\widetilde{\boldsymbol{H}})$ becomes negligible compared to the dominant term, $\mathcal{L}_{\text{NC1}}(\widetilde{\boldsymbol{H}})$, implying $\mathcal{L}_{\text{LS}}^\perp(\widetilde{\boldsymbol{W}},\widetilde{\boldsymbol{H}}){\ll}\mathcal{L}_{\text{LS}}(\widetilde{\boldsymbol{H}}){=}\mathcal{L}_{\text{NC1}}(\widetilde{\boldsymbol{H}}){+}\mathcal{L}_{\text{NC2/3}}(\widetilde{\boldsymbol{H}})$, i.e. the features and classifiers are effectively on the central path during the terminal phase of training (TPT). Note that $\mathcal{L}_{\text{NC2/3}}(\widetilde{\boldsymbol{H}})$ diminishes the fastest among all the terms: intuitively, this shows that the network primarily focuses on distributing the feature class-means into a "uniform" Simplex ETF configuration (NC2)-(NC3) early on and, from there, compresses the activations towards their class-means, i.e. (NC1), as much as possible. Further experimental details are in Appendix \ref{sec:NC_experiments}. Outlier behavior is discussed in Appendix \ref{sec:STL}.
  • Figure 3: Plots analogous to Figure 2 in papyan2020prevalence, but on networks trained with MSE Loss. Results demonstrate that last-layer features and classifiers approach equinormness.
  • Figure 4: Plots analogous to Figure 3 in papyan2020prevalence, but on networks trained with MSE Loss. Results demonstrate that last-layer features and classifiers approach equiangularity.
  • Figure 5: Plots analogous to Figure 4 in papyan2020prevalence, but on networks trained with MSE Loss. Results demonstrate that last-layer features and classifiers approach maximal-equiangularity.
  • ...and 9 more figures
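The split of the loss in Figure 2 into a least-squares part and a deviation part can be checked numerically. The sketch below works under assumptions not confirmed by this summary: the loss is taken to be (1/2N)·||WH − Y||_F² + (λ/2)·||W||_F², and the deviation term is computed as the quadratic form ½·tr((W − W_LS) A (W − W_LS)ᵀ) with A = H Hᵀ/N + λI, which is the exact excess loss over the least-squares optimum for that assumed loss. The further split of the least-squares part into NC1 and NC2/3 terms is not reproduced here, and the function name is hypothetical.

```python
import numpy as np

def decompose_mse(W, H, Y, lam):
    """Return (L_LS, L_perp, L_total) for the assumed regularized MSE loss.

    L_LS    : loss at the least-squares classifier for the fixed features H
    L_perp  : excess loss caused by W deviating from that classifier
    L_total : loss at W, which should equal L_LS + L_perp
    """
    d, N = H.shape
    A = H @ H.T / N + lam * np.eye(d)
    W_ls = np.linalg.solve(A, (Y @ H.T / N).T).T

    def loss(V):
        return 0.5 / N * np.sum((V @ H - Y) ** 2) + 0.5 * lam * np.sum(V ** 2)

    D = W - W_ls
    return loss(W_ls), 0.5 * np.trace(D @ A @ D.T), loss(W)

# Toy check with random features and an arbitrary (non-optimal) classifier.
rng = np.random.default_rng(0)
labels = np.repeat(np.arange(3), 20)
H = np.vstack([rng.normal(size=(5, labels.size)), np.ones((1, labels.size))])
Y = np.eye(3)[:, labels]
W = rng.normal(size=(3, 6))

l_ls, l_perp, l_total = decompose_mse(W, H, Y, lam=1e-2)
print("L_LS + L_perp:", l_ls + l_perp, " L_total:", l_total)
```

Because the deviation term is a nonnegative quadratic form that vanishes only when W equals the least-squares classifier, a negligible value of it is what "being on the central path" means in this simplified setting.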

Theorems & Definitions (39)

  • Proposition 1: webb1990optimised with Weight Decay
  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Definition 1: Zero Global Mean Central Path
  • Definition 2: SVD of SNR Matrix
  • Proposition 2: Dynamics of Singular Values of SNR Matrix; Proof in Appendix \ref{sec:proof_snr_flow}
  • Corollary 1: Properties of SNR Singular Values; Proof in Appendix \ref{sec:proof_ode1}
  • Corollary 2: Neural Collapse Under MSE Loss; Proof in Appendix \ref{sec:proof_ode2}
  • Theorem 3
  • ...and 29 more