How Does Learning Rate Decay Help Modern Neural Networks?
Kaichao You, Mingsheng Long, Jianmin Wang, Michael I. Jordan
TL;DR
This work questions conventional optimization-centric explanations of learning rate decay (lrDecay) in modern neural networks and proposes a pattern-complexity perspective: a large initial learning rate suppresses memorization of noisy data, while decaying the learning rate enables the learning of increasingly complex patterns. The authors validate this view on PS10, a carefully constructed dataset with tractable pattern complexity, and corroborate it with transferability experiments on real-world datasets, which show that patterns learned in later stages of lrDecay are more complex and less transferable. Targeted experiments indicate that existing explanations based on the analysis of GD/SGD are insufficient to account for the effectiveness of lrDecay in deep architectures. The findings offer a new lens for designing training strategies and model zoos, placing pattern complexity and transferability at the center of lrDecay dynamics.
Abstract
Learning rate decay (lrDecay) is a de facto technique for training modern neural networks. It starts with a large learning rate and then decays it multiple times. It is empirically observed to help both optimization and generalization. Common beliefs about how lrDecay works come from the optimization analysis of (Stochastic) Gradient Descent: 1) an initially large learning rate accelerates training or helps the network escape spurious local minima; 2) decaying the learning rate helps the network converge to a local minimum and avoid oscillation. Despite the popularity of these beliefs, experiments suggest that they are insufficient to explain the general effectiveness of lrDecay in training modern neural networks that are deep, wide, and nonconvex. We provide a novel alternative explanation: an initially large learning rate keeps the network from memorizing noisy data, while decaying the learning rate improves the learning of complex patterns. The proposed explanation is validated on a carefully constructed dataset with tractable pattern complexity. Its implication, that the additional patterns learned in later stages of lrDecay are more complex and thus less transferable, is confirmed on real-world datasets. We believe that this alternative explanation will shed light on the design of better training strategies for modern neural networks.
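To make "starts with a large learning rate and then decays it multiple times" concrete, here is a minimal sketch of a step-decay schedule. The function name `step_decay_lr` and the default values (base rate 0.1, decay by a factor of 10 at epochs 30, 60 and 90) are hypothetical illustrations chosen for this example, not settings taken from the paper.

```python
from bisect import bisect_right


def step_decay_lr(epoch, base_lr=0.1, milestones=(30, 60, 90), gamma=0.1):
    """Learning rate at `epoch` under a step-decay schedule.

    Hypothetical defaults: start at base_lr and multiply by gamma each time
    a milestone epoch is reached. `milestones` must be sorted ascending.
    """
    # Count how many milestones have been reached (a milestone equal to
    # `epoch` counts as reached), i.e. how many decays have happened so far.
    num_decays = bisect_right(milestones, epoch)
    return base_lr * gamma ** num_decays


if __name__ == "__main__":
    # Print the schedule around each decay boundary.
    for epoch in (0, 29, 30, 59, 60, 89, 90):
        print(epoch, step_decay_lr(epoch))
```

In PyTorch, `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma)` implements the same kind of piecewise-constant schedule and is usually stepped once per epoch.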
