Depth Dependence of $μ$P Learning Rates in ReLU MLPs

Samy Jelassi, Boris Hanin, Ziwei Ji, Sashank J. Reddi, Srinadh Bhojanapalli, Sanjiv Kumar

TL;DR

This work analyzes how the maximal-update μP learning rate for gradient descent in deep ReLU MLPs with mean-field initialization depends on network depth. By deriving a depth-aware recursion for the change in pre-activations, the authors show that the maximal update learning rate scales as $η^*(L) = \text{const}\cdot L^{-3/2}$, while remaining largely independent of width $n$ except for the first and last layers. The key technical contribution is a decomposition of the update variance into terms with a recursive structure across layers, yielding a $Θ(η^2 L^3)$ growth in the relevant quantity and establishing the depth scaling rigorously. This depth-aware μP result informs principled learning-rate selection and suggests potential cross-width transfer of hyperparameters when training deep ReLU networks.
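As a worked step, the exponent $-3/2$ can be read off the $Θ(η^2 L^3)$ growth. The sketch below assumes that the "relevant quantity" is the mean squared change in pre-activations after one gradient step. The symbols $\Delta(η, L)$, $c$, $C$, and $K$ are placeholder notation introduced here for illustration and do not appear in the original.

$$
c\,η^2 L^3 \;\le\; \Delta(η, L) \;\le\; C\,η^2 L^3
\quad\Longrightarrow\quad
\Big(\Delta(η, L) \le K \text{ uniformly in } L \iff η = O\big(L^{-3/2}\big)\Big),
$$

so the largest admissible learning rate is of the form $η^*(L) = \text{const}\cdot L^{-3/2}$, matching the scaling stated above.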

Abstract

In this short note we consider random fully connected ReLU networks of width $n$ and depth $L$ equipped with a mean-field weight initialization. Our purpose is to study the dependence on $n$ and $L$ of the maximal update ($μ$P) learning rate, the largest learning rate for which the mean squared change in pre-activations after one step of gradient descent remains uniformly bounded at large $n,L$. As in prior work on $μ$P of Yang et al., we find that this maximal update learning rate is independent of $n$ for all but the first and last layer weights. However, we find that it has a non-trivial dependence on $L$, scaling like $L^{-3/2}$.
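To illustrate what the $L^{-3/2}$ scaling would mean in practice, here is a minimal Python sketch that rescales a learning rate tuned at one depth to another depth. The function name `rescale_lr_for_depth` and its parameters are hypothetical and not taken from the paper. The sketch also assumes that the unknown constant in front of $L^{-3/2}$ is the same at both depths, which the stated result (one gradient step, large $n, L$) does not by itself guarantee for a full training run.

```python
def rescale_lr_for_depth(base_lr: float, base_depth: int, target_depth: int) -> float:
    """Rescale a learning rate between depths using eta(L) proportional to L**(-3/2).

    Hypothetical helper: assumes the constant prefactor is shared across depths,
    so that eta(target) / eta(base) = (base_depth / target_depth) ** 1.5.
    """
    if base_lr <= 0:
        raise ValueError("base_lr must be positive")
    if base_depth <= 0 or target_depth <= 0:
        raise ValueError("depths must be positive")
    return base_lr * (base_depth / target_depth) ** 1.5


if __name__ == "__main__":
    # Example: a rate tuned at depth 4, transferred to deeper networks.
    for depth in (4, 16, 64):
        print(depth, rescale_lr_for_depth(0.1, 4, depth))
```

Under this rule, quadrupling the depth divides the learning rate by $4^{3/2} = 8$. Width does not enter the helper, reflecting the claim that the hidden-layer rate is independent of $n$; the first and last layers would need separate treatment.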

Paper Structure (7 sections, 10 theorems, 44 equations)

Key Result

Theorem 1.1

For each $c_1>0$ there exist $c_2,c_3>0$ with the following property. Fix a network width $n$ and depth $L$ so that $L/n < c_1$. Then, where $\mathcal{B}=\left\{(x,y)\right\}$ is any batch of size one consisting of a normalized datapoint $(x,y)$ sampled independently of network weights and biases with:

Theorems & Definitions (19)

  • Theorem 1.1
  • Lemma 2.1
  • proof : Proof of Lemma 2.1
  • Lemma 2.2
  • proof : Proof of Lemma 2.2
  • Lemma 2.3
  • proof : Proof of Lemma 2.3
  • Proposition 2.4
  • proof
  • Lemma 2.5
  • ...and 9 more