
Neural Collapse versus Low-rank Bias: Is Deep Neural Collapse Really Optimal?

Peter Súkeník, Marco Mondelli, Christoph Lampert

TL;DR

This work studies non-linear models of arbitrary depth in multi-class classification and reveals a surprising qualitative shift: as soon as the model goes beyond two layers or two classes, deep neural collapse (DNC) stops being optimal for the deep unconstrained features model (DUFM), the standard theoretical framework for the analysis of collapse.

Abstract

Deep neural networks (DNNs) exhibit a surprising structure in their final layer known as neural collapse (NC), and a growing body of work has investigated the propagation of neural collapse to earlier layers of DNNs -- a phenomenon called deep neural collapse (DNC). However, existing theoretical results are restricted to special cases: linear models, only two layers, or binary classification. In contrast, we focus on non-linear models of arbitrary depth in multi-class classification and reveal a surprising qualitative shift. As soon as we go beyond two layers or two classes, DNC stops being optimal for the deep unconstrained features model (DUFM) -- the standard theoretical framework for the analysis of collapse. The main culprit is a low-rank bias of multi-layer regularization schemes: this bias leads to optimal solutions of even lower rank than neural collapse. We support our theoretical findings with experiments on both DUFM and real data, which show the emergence of the low-rank structure in the solutions found by gradient descent.
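
The DUFM objective is not written out on this page. As a rough orientation only, the sketch below assumes the formulation common in the DUFM literature: free first-layer features $H_1$, ReLU between consecutive layers, a linear output layer, a mean-squared-error fit term normalized by $2N$, and separate Frobenius-norm penalties on $H_1$ and on each $W_l$. The function name `dufm_loss` and this exact normalization are assumptions, not the paper's code.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def dufm_loss(H1, Ws, Y, lam_H1, lam_Ws):
    """Regularized MSE loss of an L-layer DUFM (hypothetical sketch, assumed formulation).

    H1: d x N unconstrained features; Ws: [W_1, ..., W_L], the last one mapping to K outputs;
    Y: K x N one-hot labels. Assumed: H_{l+1} = ReLU(W_l H_l) for l < L, output W_L H_L.
    """
    N = Y.shape[1]
    H = H1
    for W in Ws[:-1]:
        H = relu(W @ H)                        # ReLU between consecutive layers
    out = Ws[-1] @ H                           # last layer is linear
    fit = np.sum((out - Y) ** 2) / (2 * N)     # MSE fit term
    reg = lam_H1 / 2 * np.sum(H1 ** 2)         # penalty on the free features
    reg += sum(lam / 2 * np.sum(W ** 2) for lam, W in zip(lam_Ws, Ws))
    return fit + reg


# Usage: the all-zero parameters of a 4-DUFM with K = 10 classes and 5 samples per class.
K, n, d, L = 10, 5, 30, 4
N = K * n
Y = np.kron(np.eye(K), np.ones((1, n)))        # K x N one-hot labels, samples grouped by class
H1 = np.zeros((d, N))
Ws = [np.zeros((d, d)) for _ in range(L - 1)] + [np.zeros((K, d))]
print(dufm_loss(H1, Ws, Y, 0.004, [0.004] * L))  # 0.5
```

Under this assumed normalization the all-zero solution has loss exactly $1/2$, whatever the regularization strengths.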

Paper Structure

This paper contains 30 sections, 13 theorems, 96 equations, 11 figures.

Key Result

Theorem 4

If ($K\ge 6$ and $L\ge 4$) or ($K\ge 10$ and $L=3$), and $d_l\ge K$ for all $l$, then $\mathcal{L}_{SRG} < \mathcal{L}_{DNC}$. Moreover, consider any sequence of $L$-DUFM problems with $K\to\infty$ such that $\mathcal{L}_{DNC} < 0.499$ for each problem. In that case,
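
The first sentence of Theorem 4 mixes two depth/class regimes with a width condition. The predicate below spells out one reading of that hypothesis, with the width condition $d_l\ge K$ applying in both regimes; this grouping is our interpretation of the statement. It encodes only the hypothesis and does not compute $\mathcal{L}_{SRG}$ or $\mathcal{L}_{DNC}$. The name `theorem4_hypothesis` is hypothetical.

```python
def theorem4_hypothesis(K, L, widths):
    """True if (K >= 6 and L >= 4) or (K >= 10 and L == 3), and d_l >= K for every layer width.

    Hypothetical helper: checks the assumptions of Theorem 4 as read here, not its conclusion.
    """
    classes_and_depth_ok = (K >= 6 and L >= 4) or (K >= 10 and L == 3)
    widths_ok = all(d >= K for d in widths)
    return classes_and_depth_ok and widths_ok


# The 4-DUFM setting of Figure 2 (K = 10, width 30) satisfies the hypothesis.
print(theorem4_hypothesis(10, 4, [30] * 4))  # True
print(theorem4_hypothesis(6, 3, [10] * 3))   # False: L = 3 needs K >= 10
```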

Figures (11)

  • Figure 1: Strongly regular graph (SRG) solution with $L=4$, $K=10$ and $r=5$. Left: Class-mean matrix of the third layer $M_3$. The non-zero entries of each row have the same value and their number is $r-1$, which corresponds to the degree of the complete graph $\mathcal{K}_r$. Middle: Class-mean matrix of the fourth layer before ReLU $\Tilde{M}_4$ (middle left), and its Gram matrix $\Tilde{M}_4^T\Tilde{M}_4$ (middle right). The SRG construction has very low rank before ReLU: ${\rm rank}(\Tilde{M}_4)=r$ and ${\rm rank}(\sigma(\Tilde{M}_4))=K$. Right: $\Tilde{M}_4^T\Tilde{M}_4$ for DNC. The DNC solution has rank $K$ in all layers before and after ReLU.
  • Figure 2: Training loss compared against DNC and SRG losses (left), DNC1 metric training progression (middle) and singular value distribution at convergence (right). Top row: $4$-DUFM training with $K=10$, $\lambda=0.004$ for all regularization parameters, learning rate of $0.5$ and width $30$. Results are averaged over 10 runs, and we show the confidence intervals at $1$ standard deviation. Bottom row: Training of a ResNet20 with a 4-layer MLP head on CIFAR10, using a DUFM-like regularization. We use weight decay $0.005$ except $\lambda_{H_1}=0.000005$ (to compensate for $n=5000$, which significantly influences the total regularization strength), learning rate $0.05$ and width $64$ for all the MLP layers. Results are averaged over 5 runs, and we show the confidence intervals at $1$ standard deviation.
  • Figure 3: All experiments refer to the training of an $L$-DUFM model. Results are averaged over 5 runs, and we show the confidence intervals at $1$ standard deviation. Left: Ratio between SRG and DNC loss ($\mathcal{L}_{SRG}/\mathcal{L}_{DNC}$), as a function of $r$, where the number of classes is $K= {r\choose2}$. Different curves correspond to different values of $L\in\{3, 4, 5\}$. Middle: Average rank at convergence, as a function of the weight decay in $\log_2$-scale, when $L=4$ and $K=15$. Right: Empirical probability of finding a DNC solution as a function of the width, when $L=4$ and $K=10$.
  • Figure 4: Training of a ResNet20 with a 5-layer MLP head on CIFAR-10 (top row) and MNIST (bottom row), using the standard regularization. We pick a large weight decay ($0.08$ for CIFAR-10 and $0.04$ for MNIST) and a large learning rate ($0.005$ for CIFAR-10 and $0.01$ for MNIST). Results are averaged over 5 runs, and we show the confidence intervals at $1$ standard deviation. Left: DNC1 metric training progression. Middle: Singular value distributions at convergence for all the layers. Right: Gram matrices of $M_3$ (CIFAR-10) and $M_5$ (MNIST).
  • Figure 5: $4$-DUFM training for $K=3$ (top), $K=4$ (middle), and $K=5$ (bottom). Left: Loss progression, also decomposed into the fit and regularization terms. Middle left: Visualization of the matrix $M_3$. Middle right: Visualization of the matrix $\Tilde{M}_4$. Right: Visualization of the matrix $M_3^T M_3$.
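
The claim in Figure 1 that a pre-activation matrix can have rank $r$ while its ReLU has rank $K$ can be checked on a toy example. The construction below is our own illustration and is not claimed to be the paper's SRG solution. Classes are indexed by the $K={r\choose 2}$ edges of $\mathcal{K}_r$, the column for edge $\{i,j\}$ is $e_i+e_j$, and a constant shift makes every off-diagonal pre-activation negative. The Gram matrix of those columns equals $2I$ plus the adjacency matrix of the line graph of $\mathcal{K}_r$ (the triangular graph, itself strongly regular with degree $2(r-2)$). The function names and the shift $0.75$ are hypothetical choices, and the feature width is set equal to $K$ for simplicity.

```python
from itertools import combinations

import numpy as np


def edge_columns(r):
    """r x K matrix; the column for edge {i, j} of the complete graph K_r is e_i + e_j."""
    edges = list(combinations(range(r), 2))   # K = r choose 2 edges, one per class
    B = np.zeros((r, len(edges)))
    for k, (i, j) in enumerate(edges):
        B[i, k] = 1.0
        B[j, k] = 1.0
    return B


def toy_preactivation(r, shift=0.75):
    """K x K pre-ReLU matrix of rank at most r (hypothetical toy construction)."""
    B = edge_columns(r)
    A = B.T - shift          # K x r; row for edge {a, b} is e_a + e_b - shift * 1
    # Entry ({a,b}, {i,j}) is |{a,b} & {i,j}| - 2 * shift: 0.5, -0.5 or -1.5 for shift 0.75.
    return A @ B


r = 5
M_pre = toy_preactivation(r)           # 10 x 10, since K = C(5, 2) = 10
M_post = np.maximum(M_pre, 0.0)        # ReLU keeps only the diagonal entries 0.5
print(np.linalg.matrix_rank(M_pre), np.linalg.matrix_rank(M_post))  # 5 10
```

The pre-activation factors through an $r$-dimensional space, so its rank cannot exceed $r$, yet after ReLU it becomes $0.5\,I_K$ and has full rank $K$.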

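Figures 2 to 4 report a DNC1 metric, singular value distributions, ranks at convergence, and Gram matrices of class-mean matrices. Their exact definitions are not given on this page, so the helpers below are assumptions: DNC1 is taken to be the common NC1 variant $\mathrm{tr}(\Sigma_W\Sigma_B^{\dagger})/K$, and the numerical rank counts singular values above an arbitrary relative threshold of $10^{-3}$. Function names are hypothetical, and the data is synthetic.

```python
import numpy as np


def class_means(H, y, K):
    """d x K matrix whose k-th column is the mean of the features H[:, y == k]."""
    return np.stack([H[:, y == k].mean(axis=1) for k in range(K)], axis=1)


def nc1_metric(H, y, K):
    """Within-class variability tr(Sigma_W Sigma_B^+) / K (a common NC1 variant; assumed)."""
    N = H.shape[1]
    M = class_means(H, y, K)
    W = H - M[:, y]                              # deviation of each sample from its class mean
    Sigma_W = W @ W.T / N
    Mc = M - H.mean(axis=1, keepdims=True)       # class means centered by the global mean
    Sigma_B = Mc @ Mc.T / K
    return float(np.trace(Sigma_W @ np.linalg.pinv(Sigma_B)) / K)


def numerical_rank(M, rel_tol=1e-3):
    """Count singular values above rel_tol times the largest (threshold is an assumption)."""
    s = np.linalg.svd(M, compute_uv=False)
    if s.size == 0 or s[0] == 0.0:
        return 0, s
    return int(np.sum(s > rel_tol * s[0])), s


# Toy check: 4 classes with orthogonal means in R^6, 20 samples each, small noise.
rng = np.random.default_rng(0)
K, d, n = 4, 6, 20
y = np.repeat(np.arange(K), n)
mu = 3.0 * np.eye(d)[:, :K]                      # orthogonal class means
H = mu[:, y] + 0.1 * rng.standard_normal((d, K * n))
M = class_means(H, y, K)
rank, s = numerical_rank(M)
print(round(nc1_metric(H, y, K), 4), rank)
print(np.round(M.T @ M, 1))                      # Gram matrix of class means, close to 9 * I
```

With noiseless features the metric is exactly zero, and it grows with the within-class spread. Applied layer by layer, `numerical_rank` gives the kind of rank comparison shown in Figures 2 and 3.
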
Theorems & Definitions (32)

  • Definition 1
  • Definition 2
  • Definition 2
  • Definition 3
  • Theorem 4
  • Theorem 5
  • Definition 5
  • Theorem 6
  • Definition 7
  • Definition 8