On the Weight Dynamics of Deep Normalized Networks
Christian H. X. Ali Mehmeti-Göpel, Michael Wand
TL;DR
Problem: disparities in effective learning rates (ELRs) across layers of normalization-based networks hinder trainability. Approach: builds discrete- and continuous-time dynamical models of weight norms, gradient norms, and ELRs, predicting for a given constant learning rate whether layer-wise ELR disparities shrink or widen. Contributions: (i) a general auto rate-tuning theory, including a gradient-flow model $\frac{d\sigma^2}{dt}=\frac{c^2}{\sigma^2}$ with closed-form solution $\sigma^2(t)=\sqrt{\sigma^2(0)^2+2c^2t}$ and the ELR limit $\lim_{t\to\infty} \frac{E_\ell}{E_k}=1$ for any pair of layers $\ell, k$; (ii) identification of regime boundaries and a hyperparameter-free warm-up method; (iii) empirical validation on CNNs and Transformers, together with a constrained-ELR training technique. Findings: ELR spread is minimized with small or moderate learning rates and is further reduced by momentum and warm-up, enabling the training of very deep networks; constrained-ELR training can stabilize training in otherwise unstable regimes. Impact: provides practical stabilization guidelines for deep normalization-based architectures and informs design choices that improve trainability in CNNs and Transformers.
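The gradient-flow equation $\frac{d\sigma^2}{dt}=\frac{c^2}{\sigma^2}$ can be integrated by separation of variables: $\sigma^2\,d\sigma^2 = c^2\,dt$ gives $\sigma^2(t)=\sqrt{\sigma^2(0)^2+2c^2t}$. The Python sketch below checks this closed form against forward-Euler integration and illustrates how initial ELR differences wash out. It rests on two assumptions of ours: that a layer's ELR is $\eta/\sigma^2$ (a common convention for scale-invariant weights), and that both layers share the same constant $c$. The function names and the initial values 0.1 and 10 are hypothetical.

```python
import math

def sigma2_closed_form(t, sigma2_0, c):
    """Closed-form solution of d(sigma^2)/dt = c^2 / sigma^2 with sigma^2(0) = sigma2_0."""
    return math.sqrt(sigma2_0 ** 2 + 2.0 * c ** 2 * t)

def sigma2_euler(t_end, sigma2_0, c, dt=1e-3):
    """Forward-Euler integration of the same ODE, used only as a cross-check."""
    s = sigma2_0
    for _ in range(int(round(t_end / dt))):
        s += dt * c ** 2 / s
    return s

def elr_ratio(t, sigma2_a=0.1, sigma2_b=10.0, c=1.0):
    """Ratio E_a / E_b of two layers' ELRs.

    Assumes (hypothetically) ELR = eta / sigma^2 and a constant c shared by both layers;
    eta cancels in the ratio, so only the squared weight norms matter.
    """
    return sigma2_closed_form(t, sigma2_b, c) / sigma2_closed_form(t, sigma2_a, c)

for t in (0.0, 1.0, 1e2, 1e4, 1e6):
    print(f"t={t:>9.0f}  E_a/E_b = {elr_ratio(t):.4f}")
```

With a shared $c$ the ratio starts at 100 and approaches 1 as $t$ grows, because both $\sigma^2$ curves approach $\sqrt{2c^2t}$ and forget their initial values. If the two layers had fixed but different constants $c_a, c_b$, this toy model would instead give $E_a/E_b\to c_b/c_a$, so the shared-$c$ assumption carries real weight in the sketch.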
Abstract
Recent studies have shown that large disparities in effective learning rates (ELRs) across layers of deep neural networks can negatively affect trainability. We formalize how these disparities evolve over time by modeling the weight dynamics (the evolution of expected gradient and weight norms) of networks with normalization layers, which lets us predict the evolution of layer-wise ELR ratios. We prove that when training with any constant learning rate, ELR ratios converge to 1, despite initial gradient explosion. We identify a "critical learning rate", depending only on the current ELRs, beyond which ELR disparities widen. To validate our findings, we devise a hyperparameter-free warm-up method that quickly minimizes ELR spread, both in theory and in practice. Our experiments link ELR spread to trainability, a relationship that is most evident in very deep networks with significant excursions in gradient magnitude.
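The weight-norm growth underlying these dynamics can be checked on a toy example. The sketch below is an illustration under stated assumptions, not the authors' model: it stands in for a layer followed by normalization with a hypothetical scale-invariant loss $L(w)=\tfrac12\lVert A\,w/\lVert w\rVert-b\rVert^2$ (random `A`, `b`; the helper `grad_scale_invariant` is hypothetical) and takes one plain gradient step without momentum or weight decay. It verifies two standard facts about scale-invariant losses: the gradient is orthogonal to $w$, so $\lVert w_{t+1}\rVert^2=\lVert w_t\rVert^2+\eta^2\lVert g_t\rVert^2$ and the norm can only grow; and rescaling $w$ by $\alpha$ scales the gradient by $1/\alpha$. These are the usual reasons the ELR of such a layer is taken to scale as $\eta/\lVert w\rVert^2$.

```python
import numpy as np

def grad_scale_invariant(w, A, b):
    """Gradient of the toy loss L(w) = 0.5 * ||A @ (w / ||w||) - b||^2.

    L is invariant to rescaling w, mimicking a layer followed by normalization.
    """
    norm = np.linalg.norm(w)
    u = w / norm                 # unit-norm direction
    g_u = A.T @ (A @ u - b)      # gradient with respect to u
    # Chain rule through u = w / ||w||: remove the radial part, then divide by ||w||.
    return (g_u - u * (u @ g_u)) / norm

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 4))
b = rng.standard_normal(5)
w = rng.standard_normal(4)
eta = 0.1                        # constant learning rate (arbitrary choice)

g = grad_scale_invariant(w, A, b)
w_next = w - eta * g             # one plain gradient-descent step

orth = w @ g                                                # ~0: gradient is orthogonal to w
norm_gap = w_next @ w_next - (w @ w + eta**2 * (g @ g))     # ~0: Pythagorean norm growth
# ~0: ||g|| scales as 1/||w|| (compare the gradient at 3w with the one at w)
scale_gap = 3.0 * np.linalg.norm(grad_scale_invariant(3.0 * w, A, b)) - np.linalg.norm(g)

print(f"w.g = {orth:.2e}, norm gap = {norm_gap:.2e}, scale gap = {scale_gap:.2e}")
```

Because every step adds $\eta^2\lVert g_t\rVert^2$ to the squared norm while $\lVert g_t\rVert$ shrinks as $1/\lVert w_t\rVert$, layers with large gradients grow their norms fastest and thereby damp their own effective step size.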
