
Neural Collapse is Globally Optimal in Deep Regularized ResNets and Transformers

Peter Súkeník, Christoph H. Lampert, Marco Mondelli

TL;DR

The paper addresses why neural collapse (NC) emerges in modern deep architectures trained end-to-end. It develops a depth-dependent theory showing that, for deep regularized ResNets and transformers with LayerNorm, the global optima approach those of a generalized unconstrained features model (GUFM), so that NC holds approximately and the approximation tightens as depth grows. The authors prove that global optima are approximately collapsed for deep networks built from single- and double-layer blocks under cross-entropy (CE) and mean squared error (MSE) losses, give a formal end-to-end reduction to the GUFM, and report experiments on vision and language datasets in which NC becomes more pronounced with depth. This work provides a principled justification for using unconstrained features models (UFMs) to analyze deep, modern architectures and suggests that depth may be a practical lever for obtaining robust, interpretable representations.
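To make "deep regularized ResNet with LayerNorm" concrete, the sketch below stacks $L$ single-layer residual blocks, each followed by LayerNorm, and evaluates a cross-entropy loss plus weight decay with strength $\lambda$ (`lam`). It is loosely modeled on the single-layer-block networks ($L$-RN1) named in the figures, but it is a minimal NumPy sketch under assumptions, not the authors' architecture: the post-LayerNorm placement, the ReLU nonlinearity, the parameter-free LayerNorm, the linear classifier head, and the names `layer_norm`, `forward` and `regularized_ce_loss` are all hypothetical choices made for illustration.

```python
import numpy as np

def layer_norm(h, eps=1e-5):
    # Parameter-free LayerNorm over the feature axis (axis 0 = features, axis 1 = samples).
    mu = h.mean(axis=0, keepdims=True)
    var = h.var(axis=0, keepdims=True)
    return (h - mu) / np.sqrt(var + eps)

def forward(x, block_weights):
    # Hypothetical single-layer residual block with post-LayerNorm (assumption):
    # h <- LayerNorm(h + relu(W_l h)) for l = 1, ..., L.
    h = x
    for w in block_weights:
        h = layer_norm(h + np.maximum(w @ h, 0.0))
    return h  # penultimate-layer features, shape (d, N)

def regularized_ce_loss(x, y, block_weights, w_out, lam):
    # Mean cross-entropy of a linear head on the final features, plus
    # lam / 2 times the squared Frobenius norm of every weight matrix (weight decay).
    logits = w_out @ forward(x, block_weights)            # (K, N)
    logits = logits - logits.max(axis=0, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=0, keepdims=True))
    ce = -log_probs[y, np.arange(x.shape[1])].mean()
    reg = sum(np.sum(w ** 2) for w in block_weights) + np.sum(w_out ** 2)
    return ce + 0.5 * lam * reg

rng = np.random.default_rng(0)
d, K, N, L = 8, 3, 12, 4                  # feature dim, classes, samples, blocks
x = rng.standard_normal((d, N))
y = rng.integers(0, K, size=N)
blocks = [0.1 * rng.standard_normal((d, d)) for _ in range(L)]
w_out = 0.1 * rng.standard_normal((K, d))
print(regularized_ce_loss(x, y, blocks, w_out, lam=0.005))
```

Changing `L` (the number of entries in `blocks`) and `lam` mirrors the two knobs varied in the figures, the number of blocks and the regularization strength.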

Abstract

The empirical emergence of neural collapse -- a surprising symmetry in the feature representations of the training data in the penultimate layer of deep neural networks -- has spurred a line of theoretical research aimed at its understanding. However, existing work focuses on data-agnostic models or, when data structure is taken into account, it remains limited to multi-layer perceptrons. Our paper fills both these gaps by analyzing modern architectures in a data-aware regime: we prove that global optima of deep regularized transformers and residual networks (ResNets) with LayerNorm trained with cross entropy or mean squared error loss are approximately collapsed, and the approximation gets tighter as the depth grows. More generally, we formally reduce any end-to-end large-depth ResNet or transformer training into an equivalent unconstrained features model, thus justifying its wide use in the literature even beyond data-agnostic settings. Our theoretical results are supported by experiments on computer vision and language datasets showing that, as the depth grows, neural collapse indeed becomes more prominent.
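For orientation, one standard form of the unconstrained features model mentioned in the abstract, as commonly written in the neural-collapse literature for the MSE loss, treats the penultimate-layer features as free optimization variables. Here $H \in \mathbb{R}^{d\times N}$ collects the features of the $N$ training samples, $W \in \mathbb{R}^{K\times d}$ is the linear classifier, $Y \in \{0,1\}^{K\times N}$ stacks the one-hot labels of the $K$ classes, and $\lambda_W, \lambda_H > 0$ are regularization strengths. This notation is assumed for the sketch, and the display is the classical data-agnostic UFM, not the paper's generalized model (GUFM), whose exact form is not reproduced in this summary.

$$
\min_{W,\,H}\; \frac{1}{2N}\lVert WH - Y\rVert_F^2 + \frac{\lambda_W}{2}\lVert W\rVert_F^2 + \frac{\lambda_H}{2}\lVert H\rVert_F^2
$$

Because $H$ is unconstrained, nothing ties the features to the input data; the reduction described in the abstract is what connects end-to-end training of deep ResNets and transformers back to a model of this type.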

Paper Structure

This paper contains 20 sections, 11 theorems, 35 equations, and 2 figures.

Key Result

Lemma 4.0

Denote by $\operatorname{distmax}(A, B)=\sup_{x\in A}\operatorname{dist}(x,B)$ for any sets $A, B$. Then, we have
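For finite point sets, the supremum over $A$ and the infimum hidden in $\operatorname{dist}(x,B)=\inf_{y\in B}\lVert x-y\rVert$ become a max and a min, so $\operatorname{distmax}(A,B)$ is the directed (one-sided) Hausdorff distance from $A$ to $B$. The NumPy sketch below assumes the Euclidean norm and finite sets given as arrays of row vectors; the function name `distmax` is chosen for illustration. Note that the quantity is not symmetric in its arguments.

```python
import numpy as np

def distmax(A, B):
    """Directed Hausdorff distance sup_{x in A} inf_{y in B} ||x - y||
    for finite point sets A of shape (m, d) and B of shape (n, d).
    The Euclidean norm is an assumption."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    # Pairwise distances ||x - y||, shape (m, n).
    pairwise = np.linalg.norm(A[:, None, :] - B[None, :, :], axis=-1)
    # For each x in A, the distance to its closest point in B; then the worst case over A.
    return pairwise.min(axis=1).max()

A = [[0.0, 0.0], [3.0, 4.0]]
B = [[0.0, 0.0]]
print(distmax(A, B))  # 5.0: the point (3, 4) lies at distance 5 from B
print(distmax(B, A))  # 0.0: every point of B already belongs to A
```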

Figures (2)

  • Figure 1: $\log_{10}$ of the NC1, NC2 and NC3 metrics in the left, middle and right columns, respectively, as a function of the number of blocks $L$. First row: $L$-RN1 on CIFAR10; second row: $L$-T11 on CIFAR10; third row: pre-LN $L$-T11 on IMDB; fourth row: $L$-RN2 on MNIST with $\lambda\propto L^{-1}$.
  • Figure 2: MNIST training. $\log_{10}$ of the NC1, NC2 and NC3 metrics in the top, middle and bottom rows, respectively, as a function of the number of blocks $L$. The architectures are $L$-RN1 with $\lambda=0.005$, $L$-T11 with $\lambda=0.005$, and $L$-RN2 with $\lambda=0.0025$.
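The figures plot NC1, NC2 and NC3, but the captions do not spell out how they are computed. The sketch below uses one common set of definitions from the neural-collapse literature, which is an assumption and may differ in normalization from the paper's: NC1 as $\frac{1}{K}\operatorname{tr}(\Sigma_W \Sigma_B^{+})$ (within-class versus between-class covariance), NC2 as the Frobenius distance between the normalized Gram matrix of centered class means and a normalized simplex equiangular tight frame (ETF), and NC3 as the distance between the normalized classifier and the normalized matrix of centered class means. The function name `nc_metrics` and its argument layout are hypothetical.

```python
import numpy as np

def nc_metrics(H, y, W):
    """One common set of neural-collapse metrics (assumed; may differ from the paper's).

    H: (d, N) penultimate features, y: (N,) integer labels in {0, ..., K-1},
    W: (K, d) last-layer classifier weights.
    """
    K = W.shape[0]
    N = H.shape[1]
    mu_g = H.mean(axis=1, keepdims=True)                                      # global mean, (d, 1)
    means = np.stack([H[:, y == c].mean(axis=1) for c in range(K)], axis=1)  # class means, (d, K)
    M = means - mu_g                                                          # centered class means

    # NC1: within-class covariance measured against the (pseudo-inverted) between-class covariance.
    centered = H - means[:, y]                                                # h_i - mu_{y_i}
    sigma_w = centered @ centered.T / N
    sigma_b = M @ M.T / K
    nc1 = np.trace(sigma_w @ np.linalg.pinv(sigma_b)) / K

    # NC2: distance of the normalized Gram matrix of class means from a normalized simplex ETF.
    gram = M.T @ M
    etf = (np.eye(K) - np.ones((K, K)) / K) / np.sqrt(K - 1)                  # unit Frobenius norm
    nc2 = np.linalg.norm(gram / np.linalg.norm(gram) - etf)

    # NC3: misalignment between the classifier rows and the centered class means.
    nc3 = np.linalg.norm(W.T / np.linalg.norm(W) - M / np.linalg.norm(M))
    return float(nc1), float(nc2), float(nc3)

# Perfectly collapsed toy example: K = 3 classes whose means form a simplex ETF in R^3.
K, n_per_class = 3, 5
M_etf = np.eye(K) - np.ones((K, K)) / K
y = np.repeat(np.arange(K), n_per_class)
H = M_etf[:, y]        # every sample sits exactly on its class mean
W = M_etf.T            # classifier rows aligned with the class means
print(nc_metrics(H, y, W))  # all three values are ~0 up to floating-point error

rng = np.random.default_rng(0)
H_noisy = H + 0.1 * rng.standard_normal(H.shape)
print(nc_metrics(H_noisy, y, W)[0])  # NC1 becomes positive once within-class spread is added
```

Smaller values indicate stronger collapse, which is why the figures report them on a $\log_{10}$ scale.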

Theorems & Definitions (24)

  • Definition 3.1
  • Definition 3.2
  • Remark 3.3
  • Definition 3.4
  • Definition 3.5
  • Lemma 4.0
  • Proof
  • Lemma 4.0
  • Theorem 4.1
  • Corollary 4.2
  • ...and 14 more