Feature Learning Beyond the Edge of Stability

Dávid Terjék

TL;DR

The paper addresses training deep networks at learning rates beyond the Edge of Stability (EOS) by connecting sharpness and feature learning through a Taylor expansion of the minibatch loss. It introduces a homogeneous Normalized Update Parameterization ($\nu$P) with a polynomial hidden-layer width pattern and derives closed-form expressions for the first three Taylor coefficients of the minibatch loss under SGD, expressed via low-dimensional tensors and Gram matrices. A cheap depthwise gradient-scaling rule $\xi_{k,t}$ is proposed to cancel forward/backward norms, enabling training far beyond EOS with improved feature learning and implicit sharpness regularization, supported by MNIST experiments. The combination of $\nu$P with a quadratic width pattern ($r=2$) yields stable descent at high learning rates and higher last-layer soft ranks, without loss explosions or numerical errors, suggesting practical gains in feature extraction and implicit regularization. Overall, the work provides a principled pathway to leverage large learning rates for enhanced representation learning in MLPs, with clear directions for extending the theory to broader architectures and datasets.

Abstract

We propose a homogeneous multilayer perceptron parameterization with polynomial hidden layer width pattern and analyze its training dynamics under stochastic gradient descent with depthwise gradient scaling in a general supervised learning scenario. We obtain formulas for the first three Taylor coefficients of the minibatch loss during training that illuminate the connection between sharpness and feature learning, providing in particular a soft rank variant that quantifies the quality of learned hidden layer features. Based on our theory, we design a gradient scaling scheme that in tandem with a quadratic width pattern enables training beyond the edge of stability without loss explosions or numerical errors, resulting in improved feature learning and implicit sharpness regularization as demonstrated empirically.
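
For orientation, the "first three Taylor coefficients of the minibatch loss" can be read against the generic expansion of a loss along one scaled gradient step. The block below is a sketch under assumed notation, not the paper's Theorem 6: $\theta_t$ collects the weights at step $t$, $g_{k,t}$ is the minibatch gradient with respect to the $k$-th weight matrix, $\xi_{k,t}$ is the depthwise scaling factor, $\eta$ is the learning rate, and $\mathcal{L}$ is the minibatch loss. The paper's own coefficients may be normalized or indexed differently.

```latex
\begin{align*}
d_t &= \bigl(\xi_{k,t}\, g_{k,t}\bigr)_{k}
  \quad \text{(one block per weight matrix)}, \\
\mathcal{L}(\theta_t - \eta\, d_t)
  &= \mathcal{L}(\theta_t)
   - \eta\, \nabla \mathcal{L}(\theta_t)[d_t]
   + \frac{\eta^2}{2}\, \nabla^2 \mathcal{L}(\theta_t)[d_t, d_t]
   - \frac{\eta^3}{6}\, \nabla^3 \mathcal{L}(\theta_t)[d_t, d_t, d_t]
   + O(\eta^4).
\end{align*}
```

In the purely quadratic case with $d_t = \nabla \mathcal{L}(\theta_t)$ and curvature $\lambda$ along that direction, the first two terms give a decrease only when $\eta < 2/\lambda$, which is the classical stability threshold that the edge of stability refers to.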

Paper Structure

This paper contains 6 sections, 10 theorems, 123 equations, 1 figure.

Key Result

Theorem 6

For all $t \in \mathbb{N}$, we have with $\xi_t, \tau_t, T^{(1)}_t \in \mathbb{R}^{l+1}$, $T^{(2)}_t \in \mathbb{R}^{(l+1) \times (l+1)}$ and $T^{(3)}_t \in \mathbb{R}^{(l+1) \times (l+1) \times (l+1)}$ defined as

Figures (1)

  • Figure 1: Training $\nu$P MLPs with $l=8$ and $(a,b)=(\frac{1}{2},\frac{1}{2})$ to minimize the classification loss $\ell(x,y) = \log(\langle e^z, \mathbbm{1}_n \rangle) - z_y$ on MNIST with minibatch size $n=2^8$.
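
The classification loss in the caption is the usual log-sum-exp form of cross-entropy. A minimal sketch in plain Python follows, reading $z$ as the vector of network outputs (logits) and $y$ as an integer class index; both readings, and the function name `classification_loss`, are assumptions made for illustration.

```python
import math


def classification_loss(z, y):
    """Return log(sum_i exp(z_i)) - z_y for logits z and integer label y."""
    # Subtract the largest logit before exponentiating so that exp() cannot overflow.
    m = max(z)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
    return log_sum_exp - z[y]


if __name__ == "__main__":
    logits = [2.0, 0.5, -1.0]  # hypothetical network outputs for one sample
    print(classification_loss(logits, 0))
```

The max-shift does not change the value of the loss; it only keeps the intermediate exponentials in range when logits are large.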

Theorems & Definitions (24)

  • Definition 1: $(a,b)$-ReLU
  • Definition 2: Forward pass
  • Definition 3: Normalized Update Parameterization ($\nu$P)
  • Definition 4: Backward pass
  • Definition 5: Soft rank
  • Theorem 6: Taylor Coefficients
  • Proposition 7: Squared Frobenius Norm of Cumulative Updates
  • Proposition 8: Factorization of Gradients
  • proof
  • Proposition 9: Squared Frobenius Norm of Gradients
  • ...and 14 more
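
The paper's soft rank (Definition 5) is not reproduced in this listing. As a rough illustration of what a soft rank measures, the sketch below computes a different but commonly used quantity, the participation ratio $\operatorname{tr}(G)^2 / \operatorname{tr}(G^2)$ of a Gram matrix $G$ of feature vectors. This is a swapped-in stand-in, not the paper's variant, and the helper names `gram` and `participation_ratio` are hypothetical.

```python
def gram(features):
    """Gram matrix G[i][j] = <f_i, f_j> for a list of equal-length feature vectors."""
    return [[sum(a * b for a, b in zip(fi, fj)) for fj in features] for fi in features]


def participation_ratio(G):
    """tr(G)^2 / tr(G^2), i.e. (sum of eigenvalues)^2 / (sum of squared eigenvalues).

    Assumes G is symmetric and not the zero matrix.
    """
    trace = sum(G[i][i] for i in range(len(G)))
    # For symmetric G, tr(G^2) equals the sum of squared entries.
    trace_of_square = sum(v * v for row in G for v in row)
    return trace * trace / trace_of_square


if __name__ == "__main__":
    spread = gram([[1.0, 0.0], [0.0, 1.0]])     # two orthogonal features
    collapsed = gram([[1.0, 0.0], [2.0, 0.0]])  # two collinear features
    print(participation_ratio(spread), participation_ratio(collapsed))
```

The ratio ranges from 1, when all features are collinear, up to the number of features, when they are orthogonal with equal norm, so higher values indicate features spread over more directions.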