Feature learning as alignment: a structural property of gradient descent in non-linear neural networks

Daniel Beaglehole, Ioannis Mitliagkas, Atish Agarwala

TL;DR

The authors prove that derivative alignment occurs almost surely in specific high-dimensional settings, and introduce a simple optimization rule, motivated by their analysis of the centered correlation, that dramatically increases the NFA correlations at any given layer and improves the quality of the learned features.

Abstract

Understanding how neural networks extract statistics from input-label pairs through feature learning is one of the most important unsolved problems in supervised learning. Prior works demonstrated that the Gram matrices of the weights (the neural feature matrices, NFM) and the average gradient outer products (AGOP) become correlated during training, in a statement known as the neural feature ansatz (NFA). Building on the NFA, those works proposed mapping with the AGOP as a general mechanism for neural feature learning. However, they do not provide a theoretical explanation for this correlation or its origins. In this work, we further clarify the nature of this correlation and explain its emergence. We show that this correlation is equivalent to alignment between the left singular structure of the weight matrices and the newly defined pre-activation tangent features at each layer. We further establish that the alignment is driven by the interaction of weight changes induced by SGD with the pre-activation features, and we analyze the resulting dynamics analytically at early times in terms of simple statistics of the inputs and labels. We prove that derivative alignment occurs almost surely in specific high-dimensional settings. Finally, we introduce a simple optimization rule, motivated by our analysis of the centered correlation, which dramatically increases the NFA correlations at any given layer and improves the quality of the learned features.
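
The correlation described in the abstract can be made concrete with a small NumPy sketch. It assumes that the correlation $\rho(A, B)$ is the uncentered cosine similarity $\langle A, B\rangle_F / (\|A\|_F \|B\|_F)$, and that $K$ is the average outer product of the output's gradients with respect to a layer's pre-activations, so that by the chain rule the AGOP factors as $W^\top K W$ (the form used in the Figure 1 caption). The toy network $f(x) = a^\top \mathrm{relu}(Wx)$ and the names `nfc` and `layer_matrices` are hypothetical illustrations, not the authors' code. The final lines check an identity in the spirit of the alignment view (our own derivation, not a restatement of Proposition 2): writing $W = U S V^\top$, the unnormalized correlation $\mathrm{tr}(FG)$ equals $\sum_i s_i^4\, u_i^\top K u_i$, so $K$ enters only through how strongly it weights the left singular vectors $u_i$.

```python
import numpy as np


def nfc(A: np.ndarray, B: np.ndarray) -> float:
    """Uncentered correlation <A, B>_F / (||A||_F ||B||_F) (assumed definition)."""
    return float(np.sum(A * B) / (np.linalg.norm(A) * np.linalg.norm(B)))


def layer_matrices(W: np.ndarray, a: np.ndarray, X: np.ndarray):
    """NFM, K and AGOP for the hypothetical toy network f(x) = a^T relu(W x).

    W: (k, d) first-layer weights, a: (k,) readout weights, X: (n, d) inputs.
    """
    H = X @ W.T                      # pre-activations h = W x, shape (n, k)
    D = a * (H > 0)                  # df/dh = a * relu'(h) per sample, shape (n, k)
    K = D.T @ D / X.shape[0]         # average outer product of pre-activation gradients
    F = W.T @ W                      # neural feature matrix (NFM), shape (d, d)
    G = W.T @ K @ W                  # AGOP obtained through the chain rule
    # Direct AGOP from the input gradients df/dx = W^T df/dh, for comparison
    J = D @ W                        # row i is df/dx at sample i, shape (n, d)
    G_direct = J.T @ J / X.shape[0]
    return F, K, G, G_direct


rng = np.random.default_rng(0)
n, d, k = 500, 8, 16
W = rng.standard_normal((k, d)) / np.sqrt(d)
a = rng.standard_normal(k) / np.sqrt(k)
X = rng.standard_normal((n, d))

F, K, G, G_direct = layer_matrices(W, a, X)
rho = nfc(F, G)

# Unnormalized correlation expressed through the thin SVD W = U S V^T:
# tr(F G) = sum_i s_i^4 * u_i^T K u_i
U, s, Vt = np.linalg.svd(W, full_matrices=False)
lhs = np.trace(F @ G)
rhs = np.sum(s**4 * np.einsum("ji,jk,ki->i", U, K, U))
print(f"rho(F, G) = {rho:.3f}")
```

Because $F$ and $G$ are both positive semidefinite, $\mathrm{tr}(FG) \ge 0$, so under this definition the uncentered correlation always lies in $[0, 1]$.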

Paper Structure

This paper contains 48 sections, 7 theorems, 63 equations, and 17 figures.

Key Result

Proposition 2

Figures (17)

  • Figure 1: Various feature learning measures for the target function $y(x) = \sum_{k=1}^{r} x_{k \text{ mod } r} \cdot x_{(k+1) \text{ mod } r}$ with $r=5$ and inputs drawn from a standard normal distribution. The EGOP $\mathbb{E}_{x \sim \mu}\left[\frac{\partial y}{\partial x} \frac{\partial y}{\partial x}^\top\right]$ (first plot) captures the low-rank structure of the task. The NFM $\left(W^\top W\right)$ (second plot) and AGOP $\left(W^\top K W\right)$ (third plot) of a fully-connected network are similar to each other and to the EGOP. Replacing $K$ with a symmetric matrix $Q$ with the same spectrum but independent eigenvectors obscures the low-rank structure (fourth plot) and reduces the correlation from $\rho\left(F, \bar{G}\right) = 0.93$ to $\rho\left(F, W^\top Q W\right) = 0.53$.
  • Figure 2: Uncentered and centered neural feature correlations across (A,B) fully-connected, (C) convolutional, and (D) attention layers with large initialization scale. (A,C,D) show trajectories of C/UC-NFC over training. (B) shows NFC values across all layers of an MLP with five hidden layers, averaged over CIFAR-10, CIFAR-100, SVHN, MNIST, GTSRB, and STL-10 datasets. (A-C) are additionally averaged over three random seeds. Each row of (D) is an attention block (ordered from first to last in the GPT model), while the columns show correlations for query, key, and value layers, respectively.
  • Figure 3: Predicted versus observed correlation of the second derivatives of centered $F$ and $\bar{G}$ on the alignment-reversing dataset. The shaded curves correspond to four different dataset seeds, and the solid blue curve is their average. The rightmost sub-figure is a scatter plot of the predicted versus observed correlations of these second derivatives, with one point for each balance value. We instantiate the dataset in the proportional regime, where width, input dimension, and dataset size all equal $1024$.
  • Figure 4: The effect of SLO on C/UC neural feature correlations and feature learning on the chain monomial task. In the first two rows, we plot the uncentered and centered NFA for the first-layer weight matrix as a function of initialization scale, (A) with standard training and (B) with SLO. We consider a two-hidden-layer network with ReLU activations, where we set $C_0 = 500$ and $C_1 = C_2 = 0.002$. The third column shows the ratio of the unnormalized C-NFC to the UC-NFC: $\text{tr}\left(\bar{W}^\top \bar{W} \bar{W}^\top K \bar{W}\right) \cdot \text{tr}\left(W^\top W W^\top K W\right)^{-1}$. The fourth column shows the training loss. In the third row, we plot the NFM and AGOP from a trained network with (C) standard training and (D) Speed Limited Optimization (SLO), with a fixed initialization scale of $1.0$.
  • Figure 5: Ratio of the unnormalized double-centered NFC to the centered NFC throughout neural network training. Specifically, we plot $\text{tr}\left(\bar{W}^\top \bar{W} \bar{W}^\top \bar{K} \bar{W}\right) \cdot {\text{tr}\left(\bar{W}^\top \bar{W} \bar{W}^\top K \bar{W}\right)}^{-1}$ throughout training for both layers of a two-hidden-layer MLP with ReLU activations.
  • ...and 12 more figures
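
The chain monomial target from Figure 1 is simple enough to check by hand. With 0-indexed coordinates, each of the first $r$ coordinates appears in exactly two products, so $\partial y / \partial x_j = x_{(j-1) \text{ mod } r} + x_{(j+1) \text{ mod } r}$ for $j < r$, and the derivative is zero for $j \ge r$. For $r = 5$ and standard normal inputs, the population EGOP is therefore zero outside the leading $5 \times 5$ block; inside it, the entries are 2 on the diagonal, 1 between coordinates two apart (mod 5), and 0 elsewhere. The NumPy sketch below estimates the EGOP by Monte Carlo and builds a spectrum-matched matrix $Q$ as described in the caption. The input dimension $d = 10$, the sample size, and the function names are hypothetical choices. Drawing $Q$'s eigenvectors from the Haar measure via a sign-corrected QR decomposition is one assumed way to obtain "independent eigenvectors", not necessarily the authors' procedure.

```python
import numpy as np


def chain_monomial_grad(X: np.ndarray, r: int) -> np.ndarray:
    """Gradient of y(x) = sum_{k=1}^{r} x_{k mod r} x_{(k+1) mod r}, 0-indexed coordinates.

    Coordinate j < r appears in two products, so dy/dx_j = x_{(j-1) mod r} + x_{(j+1) mod r};
    coordinates j >= r do not enter y, so their partial derivatives are zero.
    """
    grad = np.zeros_like(X)
    idx = np.arange(r)
    grad[:, :r] = X[:, (idx - 1) % r] + X[:, (idx + 1) % r]
    return grad


def egop(X: np.ndarray, r: int) -> np.ndarray:
    """Monte Carlo estimate of E[(dy/dx)(dy/dx)^T] over the rows of X."""
    J = chain_monomial_grad(X, r)
    return J.T @ J / X.shape[0]


def spectrum_matched(K: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Symmetric Q with the same eigenvalues as K but random (Haar) eigenvectors (assumed construction)."""
    evals = np.linalg.eigvalsh(K)
    Z = rng.standard_normal(K.shape)
    O, R = np.linalg.qr(Z)
    O = O * np.sign(np.diag(R))  # sign correction makes O Haar-distributed
    return O @ np.diag(evals) @ O.T


rng = np.random.default_rng(0)
r, d, n = 5, 10, 200_000
X = rng.standard_normal((n, d))
E = egop(X, r)
print(np.round(E[:r, :r], 1))
```

In the Figure 1 experiment, $Q$ replaces the pre-activation matrix $K$, so with a trained weight matrix $W$ one would compare $\rho(F, W^\top K W)$ against $\rho(F, W^\top Q W)$. Because $Q$ shares $K$'s spectrum, any drop in correlation comes from the eigenvectors alone.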

Theorems & Definitions (12)

  • Proposition 2: Alignment decomposition of NFC
  • Proposition 3: Centered NFC dynamics
  • proof: Proof of Proposition (Early NFA dynamics)
  • Theorem 9: Maximum C-NFC
  • Proposition 12: Pre-activation to neural tangent identity
  • proof: Proof of Proposition 12 (Pre-activation to neural tangent identity)
  • Theorem : Maximum C-NFC
  • proof: Proof of Theorem (Maximum C-NFC)
  • Lemma 13
  • proof
  • ...and 2 more