Low Rank Gradients and Where to Find Them

Rishi Sonthalia, Michael Murray, Guido Montúfar

TL;DR

We study gradient structure in two-layer networks trained on anisotropic, ill-conditioned data with spikes. The central finding is that the input-weight gradient is generically well-approximated by a rank-two matrix formed by a residue-aligned term and a data-spike-aligned term, with an interpolant capturing their interaction. Activation choices, scaling regimes (mean-field, MF, vs. neural tangent kernel, NTK), and regularizers (weight decay, input noise, Jacobian penalties) modulate the two components, leading to regimes where one component dominates or both coexist. These insights illuminate how feature learning is guided by data structure and regularization, and are validated on synthetic data and real embeddings (MNIST/CIFAR).

Abstract

This paper investigates low-rank structure in the gradients of the training loss for two-layer neural networks while relaxing the usual isotropy assumptions on the training data and parameters. We consider a spiked data model in which the bulk can be anisotropic and ill-conditioned, we do not require the data and weight matrices to be independent, and we analyze both the mean-field and neural-tangent-kernel scalings. We show that the gradient with respect to the input weights is approximately low rank and is dominated by two rank-one terms: one aligned with the bulk data-residue, and another aligned with the rank-one spike in the input data. We characterize how properties of the training data, the scaling regime, and the activation function govern the balance between these two components. We also demonstrate that standard regularizers, such as weight decay, input noise, and Jacobian penalties, selectively modulate these components. Experiments on synthetic and real data corroborate our theoretical predictions.
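To make the object of study concrete, the following is a minimal NumPy sketch of the input-weight gradient of an MSE loss for a two-layer network, followed by a check of how much of its squared Frobenius norm sits in the top two singular values. Everything specific here is an assumption for illustration and is not taken from the paper: the data parameterization `X = X_bulk + theta * u q^T`, the tanh activation, the sign-valued outer weights `a`, the single-index targets, and the function names. It is not the authors' experimental code.

```python
import numpy as np

def make_spiked_data(n, d, theta, rng):
    """Isotropic bulk plus a rank-one spike (hypothetical parameterization)."""
    X_bulk = rng.standard_normal((n, d)) / np.sqrt(d)
    u = rng.standard_normal(n)
    u /= np.linalg.norm(u)
    q = rng.standard_normal(d)
    q /= np.linalg.norm(q)
    return X_bulk + theta * np.outer(u, q), q

def loss(W, a, X, y, gamma):
    """MSE loss of f(x) = gamma * a^T tanh(W x), averaged over the n samples."""
    preds = gamma * np.tanh(X @ W.T) @ a
    return 0.5 * np.mean((preds - y) ** 2)

def input_weight_gradient(W, a, X, y, gamma):
    """Gradient of `loss` with respect to the input weights W, shape (m, d)."""
    n = X.shape[0]
    pre = X @ W.T                          # pre-activations, shape (n, m)
    r = gamma * np.tanh(pre) @ a - y       # residual, shape (n,)
    dact = 1.0 - np.tanh(pre) ** 2         # tanh'(pre), shape (n, m)
    # dL/dW[j, :] = (gamma / n) * a_j * sum_i r_i * tanh'(w_j . x_i) * x_i
    return (gamma / n) * ((dact * r[:, None]) * a[None, :]).T @ X

rng = np.random.default_rng(0)
n, d, m = 150, 200, 250                    # small sizes chosen only for speed
X, q = make_spiked_data(n, d, theta=3.0, rng=rng)
W = rng.standard_normal((m, d))
W /= np.linalg.norm(W, axis=1, keepdims=True)   # unit-norm rows
a = rng.choice([-1.0, 1.0], size=m)             # assumed outer weights
y = np.tanh(X @ rng.standard_normal(d))         # assumed single-index targets
gamma = 1.0 / np.sqrt(m)                        # NTK-like scaling

G = input_weight_gradient(W, a, X, y, gamma)
s = np.linalg.svd(G, compute_uv=False)          # singular values, descending
top2_energy = (s[:2] ** 2).sum() / (s ** 2).sum()
print("top singular values:", s[:4])
print("fraction of squared norm in top two:", top2_energy)
```

Varying `theta`, the activation, or `gamma` in this sketch is one way to probe informally how the leading singular values move; the precise regimes are those stated in the paper's theorems, not anything this toy setup establishes.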

Paper Structure

This paper contains 34 sections, 31 theorems, 187 equations, 12 figures, 1 table.

Key Result

Proposition 2.1

If the assumption on the activation function (label assumption:activation) holds and $R$ is differentiable, then the gradient of the loss exists for almost every $W$ in $\mathbb{R}^{m \times d}$.

Figures (12)

  • Figure 1: Singular value distribution of the gradient $G$ for varying activation, loss, $\nu$, and weight distribution. Red and blue lines show the singular values of $S_1$ and $S_2$, respectively. In (a), the rows of $W$ are i.i.d. uniformly random on the unit sphere; we denote this $W=W_S$. In (b) and (c), $W = W_S + n^{-1/4}\mathbf{1}q^T$, and $W$ is then normalized. The following parameters are constant across all experiments: $\alpha = 0$, $\gamma_m = \frac{1}{\sqrt{m}}$ (NTK), $n = 750$, $d = 1000$, $m = 1250$. The targets $y$ are given by a triple-index model; see the appendix (label app:emp). For $\nu < 0.25$, a single residue-aligned spike is seen for both isotropic and non-isotropic $W$. For $\nu \in [0.25, 0.5)$, the gradient is approximately rank two.
  • Figure 2: ReLU suppresses the residue spike ($S_1$) compared to smooth activations. Fixed parameters: $\nu = 1/8$, $\alpha = 5/9$, $n = 750$, $d= 1000$, and $m = 1250$.
  • Figure 3: Singular value distributions of the gradient $G$ under various activation functions and weight matrix initializations and structures, with a large data spike $\nu = 3/4$. $W_S$ denotes the random matrix with rows drawn mutually i.i.d. uniformly from the unit sphere. The rows of $W_{S \perp q}$ are uniform on the sphere and orthogonal to $q$. All weight matrices are subsequently normalized to have unit norm rows. Fixed parameters: bulk decay exponent $\alpha = 0$, $n=750$, $d=1000$, $m=1250$, NTK-like scaling ($\gamma_m = 1/\sqrt{m}$), MSE loss, and triple-index model targets.
  • Figure 4: Empirical alignment (normalized inner product) of the top singular vector of the gradient $G$ with $X_B^T y$, $X_B^T r$, and $\omega$ for data from a single-index model $y= \text{Sigmoid}(\omega^T x) + \text{noise}$. We use isotropic $X$, ReLU activation, and MSE loss. We average over 500 samples of $a,W,X,y$. The error bars show the 25th and 75th percentiles.
  • Figure 5: Evolution of the gradient direction and weight matrix during training under GD with Weight Normalization (WN) for the MF and NTK scalings. Fixed parameters are $\nu = 0, \alpha = 0$ while using the Sigmoid activation function and the MSE loss. Plots (a) and (b) show the alignment (normalized inner product) between the leading left singular vector of the initial gradient $G_0$ (epoch 0) and that of $G_t$ (epoch $t$). Plot (c) shows the mean principal angle between the weight matrices learned under the MF and NTK scalings with identical initialization and training data.
  • ...and 7 more figures
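
The weight constructions and the alignment metric named in the captions above can be sketched as follows. This is an illustrative NumPy sketch, not the authors' code: the function names are hypothetical, `q` is assumed to be a unit vector in $\mathbb{R}^d$, and taking the absolute value in `alignment` (to remove the sign ambiguity of singular vectors) is an assumption rather than something the captions state.

```python
import numpy as np

def sphere_rows(m, d, rng):
    """W_S: rows drawn i.i.d. uniformly from the unit sphere in R^d."""
    W = rng.standard_normal((m, d))
    return W / np.linalg.norm(W, axis=1, keepdims=True)

def mean_shifted_rows(W_S, q, n):
    """W = W_S + n^{-1/4} 1 q^T, followed by row normalization."""
    ones = np.ones(W_S.shape[0])
    W = W_S + n ** (-0.25) * np.outer(ones, q)
    return W / np.linalg.norm(W, axis=1, keepdims=True)

def rows_orthogonal_to(m, d, q, rng):
    """W_{S perp q}: unit-norm rows, uniform on the sphere orthogonal to q."""
    q = q / np.linalg.norm(q)
    Z = rng.standard_normal((m, d))
    Z -= np.outer(Z @ q, q)                # remove each row's component along q
    return Z / np.linalg.norm(Z, axis=1, keepdims=True)

def alignment(u, v):
    """Normalized inner product, in [0, 1] after taking the absolute value."""
    return abs(u @ v) / (np.linalg.norm(u) * np.linalg.norm(v))

rng = np.random.default_rng(0)
n, d, m = 750, 1000, 1250                  # sizes quoted in the captions
q = rng.standard_normal(d)
q /= np.linalg.norm(q)                     # assumed unit-norm direction

W_S = sphere_rows(m, d, rng)
W_shift = mean_shifted_rows(W_S, q, n)
W_perp = rows_orthogonal_to(m, d, q, rng)

print("alignment of mean row of W_S with q:    ", alignment(W_S.mean(axis=0), q))
print("alignment of mean row of W_shift with q:", alignment(W_shift.mean(axis=0), q))
print("largest |W_perp q| entry:               ", np.abs(W_perp @ q).max())
```

Given a gradient matrix `G` of shape `(m, d)`, the same `alignment` function could be applied to a singular vector from `np.linalg.svd(G)` and a reference direction such as $X_B^T r$, provided both live in the same space.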

Theorems & Definitions (52)

  • Proposition 2.1: Gradient of the loss
  • Theorem 3.1: Gradient approximation
  • Proposition 3.1: ReLU gradient
  • Remark 1: Convolutional filters inherit the rank-two gradient
  • Theorem 3.2: Large data-spike gradient approximation
  • Proposition 4.1
  • Proposition 4.2: Isotropic Gaussian noise
  • Proposition 4.3: Gradient penalty
  • Proposition 4.4
  • Proposition A.1
  • ...and 42 more