Optimizer choice matters for the emergence of Neural Collapse

Jim Zhao, Tin Sum Cheng, Wojciech Masarczyk, Aurelien Lucchi

TL;DR

This work shows that neural collapse (NC) is not universal across optimizers: adaptive methods with decoupled weight decay (e.g., AdamW) can prevent NC, while SGD and Adam with coupled weight decay promote NC. The authors introduce NC0 as a practical diagnostic and prove theoretical results linking NC0 dynamics to optimizer settings, including exponential decay under SGD with weight decay and momentum. Through nearly 3,900 training runs on multiple architectures and datasets, they reveal a pronounced role for weight-decay coupling and momentum in shaping NC, and they demonstrate that coupling WD preserves the duality and symmetry underlying NC. The findings highlight how implicit biases of optimizers, particularly weight-decay coupling, influence the geometric organization of representations, with implications for understanding training dynamics and generalization in realistic deep networks.

Abstract

Neural Collapse (NC) refers to the emergence of highly symmetric geometric structures in the representations of deep neural networks during the terminal phase of training. Despite its prevalence, the theoretical understanding of NC remains limited. Existing analyses largely ignore the role of the optimizer, thereby suggesting that NC is universal across optimization methods. In this work, we challenge this assumption and demonstrate that the choice of optimizer plays a critical role in the emergence of NC. NC is typically quantified through NC metrics, which are difficult to track and analyze theoretically. To overcome this limitation, we introduce a novel diagnostic metric, NC0, whose convergence to zero is a necessary condition for NC. Using NC0, we provide theoretical evidence that NC cannot emerge under decoupled weight decay in adaptive optimizers, as implemented in AdamW. Concretely, we prove that SGD, SignGD with coupled weight decay (a special case of Adam), and SignGD with decoupled weight decay (a special case of AdamW) exhibit qualitatively different NC0 dynamics. We also show that momentum accelerates NC (beyond the convergence of the training loss) when training with SGD, the first result concerning momentum in the context of NC. Finally, we conduct extensive experiments comprising 3,900 training runs across various datasets, architectures, optimizers, and hyperparameters, which confirm our theoretical results. This work provides the first theoretical explanation for the optimizer-dependent emergence of NC and highlights the overlooked role of weight-decay coupling in shaping the implicit biases of optimizers.
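To make the coupled/decoupled distinction concrete, here is a minimal NumPy sketch of single update steps, written under simplifying assumptions: no momentum, a single parameter array, and hypothetical helper names (`sgd_step`, `signgd_step`) that are not taken from the paper. Coupled weight decay adds the decay term to the gradient before the update rule is applied; decoupled weight decay shrinks the weights separately from the gradient step.

```python
import numpy as np

def sgd_step(w, grad, lr, wd, decoupled):
    """One plain SGD step (no momentum) with weight decay `wd`."""
    if decoupled:
        # Decoupled: shrink the weights directly, outside the gradient.
        return w - lr * grad - lr * wd * w
    # Coupled: fold the decay term into the gradient first.
    return w - lr * (grad + wd * w)

def signgd_step(w, grad, lr, wd, decoupled):
    """One SignGD step with weight decay `wd`."""
    if decoupled:
        # The sign is taken of the loss gradient only; decay acts separately.
        return w - lr * np.sign(grad) - lr * wd * w
    # Coupled: the decay term passes through the sign nonlinearity.
    return w - lr * np.sign(grad + wd * w)

# Illustrative random weights and gradient (not data from the paper).
rng = np.random.default_rng(0)
w = rng.normal(size=(3, 4))
g = rng.normal(size=(3, 4))
lr, wd = 0.1, 0.5

# Largest entrywise difference between the decoupled and coupled updates.
sgd_gap = np.abs(sgd_step(w, g, lr, wd, True) - sgd_step(w, g, lr, wd, False)).max()
sign_gap = np.abs(signgd_step(w, g, lr, wd, True) - signgd_step(w, g, lr, wd, False)).max()
print("SGD gap:", sgd_gap)
print("SignGD gap:", sign_gap)
```

For plain SGD without momentum the two variants coincide algebraically, so the first gap is zero up to floating-point rounding. Once the update passes through a nonlinearity such as the sign, the placement of the decay term changes the step, which is the distinction the abstract draws between Adam-style and AdamW-style updates.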

Paper Structure (55 sections, 13 theorems, 114 equations, 57 figures, 7 tables)

This paper contains 55 sections, 13 theorems, 114 equations, 57 figures, 7 tables.

Key Result

Proposition 2.1

NC2 and NC3 together imply NC0.
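A short sketch of why such an implication is plausible, under an assumption this listing does not state: that NC0 measures the norm of the sum of the last-layer classifier rows, $\mathrm{NC0} = \lVert W^\top \mathbf{1}_K \rVert$. The symbols $M$ (matrix whose columns are the centered class means), $W$ (last-layer weights), $K$ (number of classes), and the constants $c, c'$ are notation chosen here for illustration. NC2 and NC3 are used in their standard senses (simplex ETF structure of the class means, and self-duality between classifier and class means). This is not the authors' proof.

```latex
% NC2 (standard form): the centered class means form a simplex ETF,
%   M^\top M = c \left( I_K - \tfrac{1}{K} \mathbf{1}_K \mathbf{1}_K^\top \right), \quad c > 0.
% NC3 (standard form): self-duality, W^\top = c' M, \quad c' > 0.
\begin{align*}
\lVert M \mathbf{1}_K \rVert^2
  &= \mathbf{1}_K^\top M^\top M \, \mathbf{1}_K
   = c \, \mathbf{1}_K^\top \left( I_K - \tfrac{1}{K} \mathbf{1}_K \mathbf{1}_K^\top \right) \mathbf{1}_K
   = c \,(K - K) = 0, \\
W^\top \mathbf{1}_K
  &= c' M \mathbf{1}_K = 0
  \quad\Longrightarrow\quad \lVert W^\top \mathbf{1}_K \rVert = 0 .
\end{align*}
```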

Figures (57)

  • Figure 1: NC0 weakly correlates with NC3 across different optimizers and learning rates. Details on the regression fit can be found in \ref{subsection:experiment:regression_fit}.
  • Figure 2: NC0 and NC3 metrics at the end of training for a ResNet9 trained on FashionMNIST with Signum and SignumW (left) and with SGD and SGDW (right). The shaded area shows one standard deviation across all training runs with the corresponding optimizer. There are fewer values for Signum and SGD because, for too-large WD, over-regularization prevented the model from training.
  • Figure 3: Train loss, train accuracy and NC metrics for fixed WD=0.005 and mom=0.7 and 0.9. Although both runs converge to almost exactly the same train loss, the final NC metrics differ considerably. Plots including NC1 and NC2 can be found in \ref{fig:Fig4_fixed_WD_varying_momentum_Fashion_ResNet9_SGDMW_mom=7e-1vs9e-1}.
  • Figure 4: Heatmap of NC0, NC2 and NC3 for varying values of momentum and weight decay on ResNet9 trained on FashionMNIST with SGD.
  • Figure 5: NC metrics throughout training on a ResNet9 trained on FashionMNIST.
  • ...and 52 more figures

Theorems & Definitions (24)

  • Proposition 2.1
  • proof
  • Theorem 3.1: SGD with decoupled weight decay promotes NC0
  • proof
  • Theorem 3.2: SGD with coupled weight decay promotes NC0
  • proof
  • Theorem 3.3: SignGD with decoupled weight decay avoids NC0
  • Theorem 3.4: SignGD with coupled weight decay can lead to NC0
  • proof
  • Theorem C.1: Theorem 3.1 and 3.2 in zhu2021geometric
  • ...and 14 more
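
The qualitative contrast between the SGD and SignGD theorems listed above can be illustrated with a toy simulation. This is a sketch under loud assumptions, not the paper's setup: only a linear last layer `W` is trained on fixed random features `H` with softmax cross-entropy, there is no momentum, and the tracked quantity is the norm of the sum of the classifier rows, which is assumed here to be what NC0 measures (the listing does not define it). The underlying fact is that each column of softmax minus one-hot sums to zero, so the cross-entropy gradient never changes the row sum of `W`; under SGD only weight decay acts on it.

```python
import numpy as np

def ce_grad(W, H, Y):
    """Gradient of the mean softmax cross-entropy w.r.t. W for logits W @ H."""
    logits = W @ H                                 # shape (K, N)
    logits -= logits.max(axis=0, keepdims=True)    # numerical stability
    P = np.exp(logits)
    P /= P.sum(axis=0, keepdims=True)              # softmax over classes
    return (P - Y) @ H.T / H.shape[1]              # shape (K, d)

def row_sum_norm(W):
    """Norm of the sum of classifier rows (assumed stand-in for NC0)."""
    return np.linalg.norm(W.sum(axis=0))

# Toy data: K classes, d features, N samples (all values are illustrative).
rng = np.random.default_rng(0)
K, d, N = 3, 5, 30
H = rng.normal(size=(d, N))
labels = rng.integers(0, K, size=N)
Y = np.eye(K)[:, labels]                           # one-hot targets, shape (K, N)
W0 = rng.normal(size=(K, d))
lr, wd, steps = 0.1, 0.1, 200

W_sgd, W_sign = W0.copy(), W0.copy()
for _ in range(steps):
    # SGD with decoupled weight decay.
    W_sgd = W_sgd - lr * ce_grad(W_sgd, H, Y) - lr * wd * W_sgd
    # SignGD with decoupled weight decay.
    W_sign = W_sign - lr * np.sign(ce_grad(W_sign, H, Y)) - lr * wd * W_sign

# Under SGD the row sum shrinks by exactly (1 - lr * wd) per step.
predicted = (1 - lr * wd) ** steps * row_sum_norm(W0)
print("SGD row-sum norm:   ", row_sum_norm(W_sgd))
print("Predicted by decay: ", predicted)
print("SignGD row-sum norm:", row_sum_norm(W_sign))
```

Under SGD the row-sum norm follows the geometric decay `(1 - lr * wd) ** steps` up to rounding. Under SignGD with decoupled decay, the signs of the gradient entries need not cancel across classes, so each step can push the row sum away from zero again.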