Depth Dependence of $μ$P Learning Rates in ReLU MLPs
Samy Jelassi, Boris Hanin, Ziwei Ji, Sashank J. Reddi, Srinadh Bhojanapalli, Sanjiv Kumar
TL;DR
This work analyzes how the maximal-update μP learning rate for gradient descent in deep ReLU MLPs with mean-field initialization depends on network depth. By deriving a depth-aware recursion for the change in pre-activations, the authors show that the maximal update learning rate scales as $η^*(L) = \text{const}\cdot L^{-3/2}$, while remaining largely independent of width $n$ except for the first and last layers. The key technical contribution is a decomposition of the update variance into terms with a recursive structure across layers, yielding a $Θ(η^2 L^3)$ growth in the relevant quantity and establishing the depth scaling rigorously. This depth-aware μP result informs principled learning-rate selection and suggests potential cross-width transfer of hyperparameters when training deep ReLU networks.
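The step from the $Θ(η^2 L^3)$ growth to the $L^{-3/2}$ scaling can be spelled out in one line. Here $Δ$ is shorthand introduced only for this illustration (it is not notation taken from the paper) for the mean squared change in pre-activations after one gradient step. Requiring $Δ$ to stay bounded and non-vanishing as $L$ grows gives

$$
Δ = Θ(η^2 L^3), \qquad Δ = Θ(1) \;\Longleftrightarrow\; η^2 L^3 = Θ(1) \;\Longleftrightarrow\; η = Θ(L^{-3/2}).
$$

A learning rate decaying more slowly than $L^{-3/2}$ would make $Δ$ grow with depth, while one decaying faster would make the update vanish, which is why $L^{-3/2}$ is the maximal-update choice.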
Abstract
In this short note we consider random fully connected ReLU networks of width $n$ and depth $L$ equipped with a mean-field weight initialization. Our purpose is to study the dependence on $n$ and $L$ of the maximal update ($μ$P) learning rate, the largest learning rate for which the mean squared change in pre-activations after one step of gradient descent remains uniformly bounded at large $n, L$. As in prior work on $μ$P of Yang et al., we find that this maximal update learning rate is independent of $n$ for all but the first and last layer weights. However, we find that it has a non-trivial dependence on $L$, scaling like $L^{-3/2}$.
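As a rough illustration of how the $L^{-3/2}$ scaling might be used in practice, the sketch below rescales a learning rate tuned at one depth to another depth. The function name `scale_lr_with_depth` and its parameters are hypothetical and not taken from the note. The note does not give the constant in front of $L^{-3/2}$, so the sketch assumes it is calibrated empirically at a reference depth and that the same constant carries over to other depths.

```python
def scale_lr_with_depth(base_lr: float, base_depth: int, target_depth: int) -> float:
    """Rescale a learning rate across depths using eta(L) proportional to L^(-3/2).

    Hypothetical helper: assumes `base_lr` was tuned at `base_depth` and that
    the unknown constant in eta(L) = const * L^(-3/2) is the same at both depths.
    Per the abstract, this applies to hidden-layer weights; the first and last
    layers also depend on width and are not handled here.
    """
    if base_depth <= 0 or target_depth <= 0:
        raise ValueError("depths must be positive integers")
    # The constant cancels in the ratio: eta(L) / eta(L0) = (L0 / L)^(3/2).
    return base_lr * (base_depth / target_depth) ** 1.5


if __name__ == "__main__":
    # Example: a rate tuned at depth 4, transferred to deeper networks.
    for depth in (4, 16, 64):
        print(depth, scale_lr_with_depth(0.1, 4, depth))
```

Because the rule only uses the ratio of depths, quadrupling the depth divides the learning rate by $4^{3/2} = 8$.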
