Deep Neural Regression Collapse

Akshay Rangamani, Altay Unal

Abstract

Neural Collapse is a phenomenon that helps identify sparse and low-rank structures in deep classifiers. Recent work has extended the definition of neural collapse to regression problems, albeit only measuring the phenomenon at the last layer. In this paper, we establish that Neural Regression Collapse (NRC) also occurs below the last layer across different types of models. We show that in the collapsed layers of neural regression models, the features lie in a subspace whose dimension matches the target dimension, the feature covariance aligns with the target covariance, the input subspace of the layer weights aligns with the feature subspace, and the linear prediction error of the features is close to the overall prediction error of the model. In addition to establishing Deep NRC, we show that models exhibiting Deep NRC learn the intrinsic dimension of low-rank targets, and we explore the necessity of weight decay in inducing Deep NRC. This paper provides a more complete picture of the simple structure learned by deep networks in the context of regression.
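Three of the properties listed above can be turned into simple numerical checks on a feature matrix. The sketch below is a minimal illustration, not the authors' code: it assumes NRC1 is the share of feature energy outside the top-$t$ principal directions, uses linear CKA for NRC2, and uses the MSE of an ordinary-least-squares probe for NRC4. The function names and the synthetic "collapsed" and "random" features are hypothetical.

```python
import numpy as np


def center(M):
    """Subtract column means so each column has zero mean over the N samples."""
    return M - M.mean(axis=0, keepdims=True)


def nrc1_noise_fraction(H, t):
    """Share of feature energy outside the top-t principal directions of H.
    Assumption: NRC1 is read as the energy in the bottom (h - t) subspace."""
    s = np.linalg.svd(center(H), compute_uv=False)  # singular values, descending
    energy = s ** 2
    return float(energy[t:].sum() / energy.sum())


def linear_cka(H, Y):
    """Linear CKA between features H (N x h) and targets Y (N x t)."""
    H, Y = center(H), center(Y)
    cross = np.linalg.norm(Y.T @ H, "fro") ** 2
    norm_h = np.linalg.norm(H.T @ H, "fro")
    norm_y = np.linalg.norm(Y.T @ Y, "fro")
    return float(cross / (norm_h * norm_y))


def linear_probe_mse(H, Y):
    """MSE of the least-squares affine prediction of Y from H."""
    X = np.hstack([H, np.ones((H.shape[0], 1))])  # append a bias column
    coef, *_ = np.linalg.lstsq(X, Y, rcond=None)
    return float(np.mean((X @ coef - Y) ** 2))


# Hypothetical data: "collapsed" features are a linear image of the targets
# plus small noise; "random" features carry no target information.
rng = np.random.default_rng(0)
N, t, h = 500, 3, 64
Y = rng.standard_normal((N, t))
H_collapsed = Y @ rng.standard_normal((t, h)) + 0.01 * rng.standard_normal((N, h))
H_random = rng.standard_normal((N, h))

for name, feats in [("collapsed", H_collapsed), ("random", H_random)]:
    print(name,
          "NRC1:", round(nrc1_noise_fraction(feats, t), 4),
          "CKA:", round(linear_cka(feats, Y), 4),
          "probe MSE:", round(linear_probe_mse(feats, Y), 4))
```

On the collapsed features the noise fraction and probe MSE should be near zero and the CKA near one, while the random features should show the opposite pattern.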

Paper Structure (27 sections, 1 theorem, 3 equations, 9 figures, 2 tables)

This paper contains 27 sections, 1 theorem, 3 equations, 9 figures, 2 tables.

Key Result

Proposition 1

Let $\bm{H} \in \mathbb{R}^{N \times h}$ be the centered feature matrix from any layer in a deep regression model and $\bm{Y} \in \mathbb{R}^{N \times t}$ be the centered targets. Let $\mathcal{P}_{\bm{Y}}$ be the projection onto the target subspace, $\bm{H}_{sig} = \mathcal{P}_{\bm{U}}$ denote the
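Proposition 1 is phrased in terms of the projection $\mathcal{P}_{\bm{Y}}$ onto the target subspace. The following sketch assumes that this subspace is the column space of the centered targets in $\mathbb{R}^N$, so that $\mathcal{P}_{\bm{Y}} = \bm{Y}\bm{Y}^{+}$ with $\bm{Y}^{+}$ the pseudoinverse. The helper `target_energy_fraction` is a hypothetical summary introduced here for illustration; it is not a quantity defined in the proposition.

```python
import numpy as np


def target_projection(Y):
    """N x N orthogonal projection onto the column space of the centered targets."""
    Yc = Y - Y.mean(axis=0, keepdims=True)
    # Yc @ pinv(Yc) projects onto col(Yc); pinv also handles rank-deficient Y.
    return Yc @ np.linalg.pinv(Yc)


def target_energy_fraction(H, Y):
    """Hypothetical summary: ||P_Y H||_F^2 / ||H||_F^2 for centered features H."""
    Hc = H - H.mean(axis=0, keepdims=True)
    P = target_projection(Y)
    return float(np.linalg.norm(P @ Hc, "fro") ** 2 / np.linalg.norm(Hc, "fro") ** 2)


rng = np.random.default_rng(1)
N, t, h = 200, 4, 32
Y = rng.standard_normal((N, t))
P = target_projection(Y)

# A projector is symmetric and idempotent, with trace equal to its rank (t here).
print(np.allclose(P, P.T), np.allclose(P @ P, P), round(float(np.trace(P)), 6))

H_collapsed = Y @ rng.standard_normal((t, h)) + 0.01 * rng.standard_normal((N, h))
H_random = rng.standard_normal((N, h))
print(target_energy_fraction(H_collapsed, Y), target_energy_fraction(H_random, Y))
```

Features that are a linear function of the targets keep almost all of their energy under $\mathcal{P}_{\bm{Y}}$, whereas unrelated features keep only about $t/N$ of it.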

Figures (9)

  • Figure 1: Deep NRC in ResNets: NRC measurements from a ResNet34 trained on the age-regression task in UTKFace (left column) and a ResNet18 trained on Carla2D (right column). The vertical green line in all plots indicates the first collapsed layer. The first row (NRC1) shows that the noise component is a small fraction of the energy in the collapsed layer representations. The second row shows the CKA between layer features and the target (NRC2). The third row (NRC3) shows the alignment between the features and the weights in the collapsed layers, and the final row (NRC4) shows the MSE of linearly predicting the targets from the features in each layer.
  • Figure 2: Deep NRC in MLPs: NRC measurements from $8$-layer MLPs trained on imitation learning tasks in the MuJoCo Swimmer (left column) and Hopper (right column) environments. The vertical green line in all plots indicates the first collapsed layer. The top row (NRC1) shows that the noise component is a small fraction of the energy in the collapsed layer representations. The second row shows the CKA between layer features and the target (NRC2). The third row (NRC3) shows the alignment between the features and the weights in the collapsed layers, and the bottom row (NRC4) shows the MSE of linearly predicting the targets from the features in each layer.
  • Figure 3: Learning the intrinsic dimension of low-rank targets: Top row - $8$-layer, $256$-width MLP trained on SGEMM ($4$-dim target, rank-$1$). Bottom row - $10$-layer, $1024$-width MLP trained on a synthetic low-rank nonlinear dataset ($10$-dim target, rank-$2$). In the left column we plot the noise component using the NRC1 formula, but measured using the bottom $(h-r)$-dimensional subspace rather than the bottom $(h-t)$-dimensional one. In the middle column, we plot the stable rank of the layer features and observe that it matches the target stable rank in the collapsed layers. In the right column, we plot the feature-weight alignment (NRC3) using the top-$r$ dimensional subspace in purple and the alignment between the features and the bottom $(h-r)$-dimensional subspace of the weights in salmon. These plots establish that collapsed layers learn the intrinsic dimension of the data.
  • Figure 4: Effect of weight decay: Left column - We train ResNet18s on CARLA2D with varying values of weight decay and observe the effects on deep neural regression collapse. In the top and middle rows we plot the NRC1 and NRC3 metrics. The bottom row shows the stable rank of the weights. In each plot, the measurements with the appropriate value of weight decay ($\lambda = 5 \times 10^{-3}$) are shown in purple, while the measurements with smaller values of weight decay ($\lambda \in \{5 \times 10^{-4}, 5 \times 10^{-5}, 0\}$) are in salmon. This shows that weight decay is necessary to achieve feature-weight alignment, which implies a low-rank bias in the weights of the layers. In the right column, an experiment on synthetic data shows how a larger value of weight decay ($\lambda = 10^{-3}$) can induce NRC1 and NRC3 (top and middle rows) but leads to worse prediction (bottom row). While weight decay is necessary for observing Deep NRC, too high a value can hinder learning.
  • Figure 5: Loss Plots: The training and test loss curves for the collapsed models trained on SGEMM (top left), Swimmer (top right), Reacher (middle left), Hopper (middle right), Carla2D (bottom left), and UTKFace (bottom right). The vertical green line indicates the epoch at which the trained models first exhibit Deep NRC. We observe that the collapsed models exhibit Deep NRC after both the training and test losses have stabilized. We also observe that the generalization gap is small for the models experiencing Deep NRC, suggesting that Deep NRC promotes the generalization capabilities of deep learning models.
  • ...and 4 more figures
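Figures 3 and 4 rely on the stable rank and on the alignment between the features and the input subspace of the weights. The sketch below assumes the standard definition of stable rank, $\|\bm{M}\|_F^2 / \|\bm{M}\|_2^2$, and measures alignment as the share of feature energy inside the top-$r$ right singular subspace of a weight matrix `W` of shape `(out_dim, h)`; that alignment formula is an assumption and may differ from the paper's NRC3. The helper names and the rank-$2$ synthetic setup (loosely echoing the $10$-dim, rank-$2$ target of Figure 3) are hypothetical.

```python
import numpy as np


def stable_rank(M):
    """Stable rank ||M||_F^2 / ||M||_2^2, a noise-tolerant proxy for rank."""
    s = np.linalg.svd(M, compute_uv=False)  # singular values, descending
    return float((s ** 2).sum() / s[0] ** 2)


def feature_weight_alignment(H, W, r):
    """Share of centered feature energy inside the top-r input subspace of W.

    W has shape (out_dim, h) and acts on h-dimensional features, so its input
    subspace is spanned by its leading right singular vectors. This is an
    assumed NRC3-style measure, not necessarily the paper's exact formula.
    """
    Hc = H - H.mean(axis=0, keepdims=True)
    _, _, Vt = np.linalg.svd(W, full_matrices=False)
    V_r = Vt[:r].T  # (h, r) orthonormal basis of the top-r input subspace
    return float(np.linalg.norm(Hc @ V_r, "fro") ** 2 / np.linalg.norm(Hc, "fro") ** 2)


# Hypothetical setup: a 10-dim target with intrinsic rank 2.
rng = np.random.default_rng(2)
N, t, r, h, out_dim = 500, 10, 2, 64, 16
Z = rng.standard_normal((N, r))          # latent 2-dim factors
Y = Z @ rng.standard_normal((r, t))      # rank-2 targets
C = rng.standard_normal((r, h))          # maps factors into feature space
H = Z @ C + 0.01 * rng.standard_normal((N, h))

# Aligned weights read from the same 2-dim subspace the features occupy.
W_aligned = rng.standard_normal((out_dim, r)) @ C + 0.01 * rng.standard_normal((out_dim, h))
W_random = rng.standard_normal((out_dim, h))

Hc = H - H.mean(axis=0, keepdims=True)
Yc = Y - Y.mean(axis=0, keepdims=True)
print("stable rank of features:", round(stable_rank(Hc), 3))
print("stable rank of targets:", round(stable_rank(Yc), 3))
print("alignment, aligned W:", round(feature_weight_alignment(H, W_aligned, r), 4))
print("alignment, random W:", round(feature_weight_alignment(H, W_random, r), 4))
```

In this toy setup both stable ranks are at most about $2$, reflecting the rank-$2$ structure, and only the weights whose input subspace matches the feature subspace capture nearly all of the feature energy.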

Theorems & Definitions (2)

  • Proposition 1
  • Proof