On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective
Zeke Xie, Zhiqiang Xu, Jingzhao Zhang, Issei Sato, Masashi Sugiyama
TL;DR
The paper identifies overlooked pitfalls of weight decay when training deep networks with adaptive learning rates, notably the emergence of large gradient norms late in training. It introduces a gradient-norm-aware framework and a practical Scheduled Weight Decay (SWD) scheduler that couples weight decay to the effective learning rate, alongside AdamS, an optimizer combining SWD with Adam. The authors provide theoretical results on stationary-point stability and convergence time, propose a large-batch weight-decay scaling rule, and demonstrate empirical improvements on CIFAR-10/100 with several architectures, while noting more modest gains on ImageNet and language modeling. The work offers a principled path to better generalization and stability for adaptive gradient methods in both standard and large-batch training, and it points toward future schedulers that treat early and late training phases differently.
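To make the description of AdamS more concrete, here is a minimal NumPy sketch of one Adam step with a scheduled weight decay. It assumes, as one reading of "couples weight decay to the effective learning rate", that the decoupled decay term is scaled by `lr / sqrt(mean(v_hat))` rather than by the constant `lr` used in AdamW-style decay. The function name `adams_step`, the default hyperparameters, and the placement of `eps` in the decay term are illustrative assumptions, not the authors' released implementation.

```python
import numpy as np


def adams_step(theta, grad, m, v, t, lr=1e-3, weight_decay=5e-4,
               beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step with a scheduled (gradient-norm-aware) weight decay.

    Assumption: the decay term is scaled by lr / sqrt(mean(v_hat)), i.e. a
    mean effective learning rate, instead of the constant lr used by AdamW.
    `t` is the 1-based step count used for bias correction.
    """
    # Standard Adam first/second moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)

    # Scheduled weight decay: the decay step shrinks as the mean second-moment
    # estimate (a proxy for the squared gradient norm) grows, and grows as it
    # shrinks. eps only guards against division by zero (an assumption).
    v_bar = float(np.mean(v_hat))
    decay_scale = lr / (np.sqrt(v_bar) + eps)

    # Decoupled decay plus the usual per-coordinate adaptive update.
    theta = (theta
             - decay_scale * weight_decay * theta
             - lr * m_hat / (np.sqrt(v_hat) + eps))
    return theta, m, v


if __name__ == "__main__":
    # Toy usage on f(theta) = ||theta||^2, whose gradient is 2 * theta.
    theta = np.array([1.0, -2.0, 0.5])
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    for t in range(1, 201):
        theta, m, v = adams_step(theta, 2 * theta, m, v, t, lr=1e-2)
    print("final gradient norm:", np.linalg.norm(2 * theta))
```

Compared with an AdamW-style step, the only change in this sketch is `decay_scale`. In a real network, `v_bar` would presumably be averaged over all parameters jointly rather than over a single array.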
Abstract
Weight decay is a simple yet powerful regularization technique that has been widely used in training deep neural networks (DNNs). While weight decay has attracted much attention, previous studies have failed to identify some overlooked pitfalls concerning the large gradient norms that weight decay can cause. In this paper, we discover that weight decay can unfortunately lead to large gradient norms in the final phase of training (i.e., at the terminated solution), which often indicates bad convergence and poor generalization. To mitigate these gradient-norm-centered pitfalls, we present the first practical scheduler for weight decay, called the Scheduled Weight Decay (SWD) method, which dynamically adjusts the weight decay strength according to the gradient norm and significantly penalizes large gradient norms during training. Our experiments also show that SWD indeed mitigates large gradient norms and often significantly outperforms the conventional constant weight decay strategy for Adaptive Moment Estimation (Adam).
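The diagnostic the abstract relies on, a large gradient norm at the end of training, can be tracked with a few lines of code. The sketch below is a minimal NumPy helper, assuming the gradients of all parameters are available as a list of arrays at each step; the helper names and the default 10% tail window are hypothetical illustrations, not the paper's evaluation protocol.

```python
import numpy as np


def global_grad_norm(grads):
    """L2 norm of all gradient entries, taken jointly over every parameter array."""
    return float(np.sqrt(sum(np.sum(np.square(g)) for g in grads)))


def final_phase_grad_norm(norm_history, tail_fraction=0.1):
    """Mean gradient norm over the last `tail_fraction` of recorded steps.

    Hypothetical helper: the tail fraction is an illustrative choice,
    not a value taken from the paper.
    """
    if not norm_history:
        raise ValueError("norm_history is empty")
    # Always average over at least one step.
    k = max(1, int(len(norm_history) * tail_fraction))
    return float(np.mean(norm_history[-k:]))
```

Recording `global_grad_norm` at every step of two otherwise identical runs, one with constant weight decay and one with a scheduled variant, and comparing their `final_phase_grad_norm` values is one simple way to check whether a run ends with the kind of large gradient norm described above.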
