Wide Neural Networks Trained with Weight Decay Provably Exhibit Neural Collapse

Arthur Jacot, Peter Súkeník, Zihan Wang, Marco Mondelli

TL;DR

This paper gives the first proof of neural collapse in the end-to-end training of DNNs. It establishes generic guarantees on neural collapse under two assumptions: (i) low training error and balancedness of the linear layers, and (ii) bounded conditioning of the features before the linear part.

Abstract

Deep neural networks (DNNs) at convergence consistently represent the training data in the last layer via a highly symmetric geometric structure referred to as neural collapse. This empirical evidence has spurred a line of theoretical research aimed at proving the emergence of neural collapse, mostly focusing on the unconstrained features model. Here, the features of the penultimate layer are free variables, which makes the model data-agnostic and, hence, puts into question its ability to capture DNN training. Our work addresses the issue, moving away from unconstrained features and studying DNNs that end with at least two linear layers. We first prove generic guarantees on neural collapse that assume (i) low training error and balancedness of the linear layers (for within-class variability collapse), and (ii) bounded conditioning of the features before the linear part (for orthogonality of class-means, as well as their alignment with weight matrices). We then show that such assumptions hold for gradient descent training with weight decay: (i) for networks with a wide first layer, we prove low training error and balancedness, and (ii) for solutions that are either nearly optimal or stable under large learning rates, we additionally prove the bounded conditioning. Taken together, our results are the first to show neural collapse in the end-to-end training of DNNs.

Paper Structure

This paper contains 21 sections, 14 theorems, 105 equations, 5 figures.

Key Result

Theorem 3.1

If the network satisfies … then, if $\epsilon_1\le \min\left(s_K(Y), \sqrt{\frac{(K-1)N}{4K}}\right)$, …, where $\Psi(\epsilon_1, \epsilon_2, r)=r\left(\frac{\epsilon_{1}}{s_K(Y)-\epsilon_1}+\sqrt{n_{L-1} \epsilon_2}\right)$. If we additionally assume that the linear part of the network is not too ill-conditioned, i.e., $\kappa(W_{L:L_1+1})\le c_3$, then with $\epsilon = \ldots$
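To get a feel for the quantities in the statement, the following minimal sketch evaluates $\Psi(\epsilon_1, \epsilon_2, r)$ and the admissible range for $\epsilon_1$ on a toy example. It rests on assumptions not confirmed by the excerpt above: $s_K(Y)$ is read as the $K$-th largest singular value of a $K \times N$ one-hot label matrix $Y$, $n_{L-1}$ as the width of layer $L-1$, and the function names `psi` and `eps1_threshold` are hypothetical.

```python
import numpy as np

def psi(eps1, eps2, r, s_k, n_width):
    """Psi(eps1, eps2, r) = r * (eps1 / (s_K(Y) - eps1) + sqrt(n_{L-1} * eps2))."""
    return r * (eps1 / (s_k - eps1) + np.sqrt(n_width * eps2))

def eps1_threshold(s_k, num_classes, num_samples):
    """Upper bound on eps1: min(s_K(Y), sqrt((K - 1) N / (4 K)))."""
    return min(s_k, np.sqrt((num_classes - 1) * num_samples / (4 * num_classes)))

# Toy setup (assumption): K = 10 balanced classes, N = 1000 samples, one-hot labels.
K, N = 10, 1000
labels = np.repeat(np.arange(K), N // K)
Y = np.eye(K)[:, labels]                          # shape (K, N)
s_k = np.linalg.svd(Y, compute_uv=False)[K - 1]   # K-th largest singular value

threshold = eps1_threshold(s_k, K, N)
value = psi(eps1=1.0, eps2=1e-4, r=2.0, s_k=s_k, n_width=64)
print(threshold, value)
```

In this balanced toy case every singular value of $Y$ equals $\sqrt{N/K}$, so the first term of the minimum is the binding one. Note that $\Psi$ shrinks as both the training error $\epsilon_1$ and the balancedness gap $\epsilon_2$ go to zero.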

Figures (5)

  • Figure 1: Last 7 layers of a 9-layer MLP trained on MNIST with weight decay $0.0018$ and learning rate $0.001$. Top: NC1s, NC2s, balancednesses and negativities, from left to right. Results are averaged over 5 runs, and the confidence band at 1 standard deviation is displayed. Bottom: Class-mean matrices of the last three layers (i.e., the linear head), the first before the last ReLU.
  • Figure 2: Upper/Lower row: MLP/ResNet20 with a deep linear head. Left to right: NC1 in the last layer; NC1 in the first layer of the linear head; NC2 in the last layer; NC2 in the first layer of the linear head. All plots are a function of the number of layers in the linear head. Results are averaged over 50 runs (5 runs for each of the 10 hyperparameter setups), and the confidence band at 1 standard deviation is displayed.
  • Figure 3: Left to right: Minimum balancedness; mean balancedness; minimum negativity; mean negativity across non-linear layers of the head as a function of the number of non-linear layers. Results are averaged over 10 runs (5 runs for each of the 2 hyperparameter setups), and the confidence band at 1 standard deviation is displayed.
  • Figure 4: Last 7 layers of a 9-layer MLP trained on MNIST with weight decay $0.0018$ and learning rate $0.001$. Top: NC1s, NC2s, balancednesses and negativities, from left to right. Results are averaged over 5 runs, and the confidence band at 1 standard deviation is displayed. Bottom: Class-mean matrices of the last four layers (i.e., the linear head).
  • Figure 5: ResNet20 trained on MNIST with a deep linear head. Left to right: NC1 in the last layer; NC1 in the first layer of the linear head; NC2 in the last layer; NC2 in the first layer of the linear head. All plots are a function of the number of layers in the linear head. Results are based on 50 runs (5 runs for each of the 10 hyperparameter setups), and the confidence band at 1 standard deviation is displayed.
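As a rough illustration of the kind of metrics the figures report, the sketch below computes a common proxy for within-class variability collapse (NC1) and a balancedness gap between two consecutive linear layers. The exact definitions used in the paper may differ: the ratio of within-class to between-class scatter traces and the operator norm of $W_{\ell+1}^\top W_{\ell+1} - W_\ell W_\ell^\top$ are assumptions here, and the function names are hypothetical.

```python
import numpy as np

def nc1_ratio(features, labels):
    """Within-class scatter trace divided by between-class scatter trace.

    features: array of shape (N, d); labels: integer array of shape (N,).
    A value near 0 means the features of each class sit close to their class mean.
    """
    global_mean = features.mean(axis=0)
    within, between = 0.0, 0.0
    for c in np.unique(labels):
        class_feats = features[labels == c]
        class_mean = class_feats.mean(axis=0)
        within += ((class_feats - class_mean) ** 2).sum()
        between += len(class_feats) * ((class_mean - global_mean) ** 2).sum()
    return within / between

def balancedness_gap(w_next, w_prev):
    """Operator norm of W_{l+1}^T W_{l+1} - W_l W_l^T (0 for perfectly balanced layers).

    w_prev has shape (n_l, n_{l-1}); w_next has shape (n_{l+1}, n_l).
    """
    return np.linalg.norm(w_next.T @ w_next - w_prev @ w_prev.T, ord=2)

# Fully collapsed toy features: every sample equals its class mean.
feats = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
labs = np.array([0, 0, 1, 1])
collapsed = nc1_ratio(feats, labs)

# An orthogonal matrix followed by the identity is a perfectly balanced pair.
theta = 0.3
rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
gap = balancedness_gap(rotation, np.eye(2))
print(collapsed, gap)
```

Both quantities vanish in the idealized cases constructed here; on a trained network one would evaluate them layer by layer on the stored features and weights.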

Theorems & Definitions (22)

  • Theorem 3.1
  • Theorem 4.4
  • Proposition 4.4
  • Proposition 5.0
  • Theorem 5.1
  • Proposition 5.1
  • Theorem B.1
  • proof
  • Proposition B.0
  • proof
  • ...and 12 more