On the Overlooked Pitfalls of Weight Decay and How to Mitigate Them: A Gradient-Norm Perspective

Zeke Xie, Zhiqiang Xu, Jingzhao Zhang, Issei Sato, Masashi Sugiyama

TL;DR

The paper identifies overlooked pitfalls of weight decay when training deep networks with adaptive learning rates, notably the emergence of large gradient norms late in training. It introduces a gradient-norm–aware framework and a practical Stable Weight Decay (SWD) scheduler that couples weight decay to the effective learning rate, alongside AdamS, an optimizer combining SWD with Adam. The authors provide theoretical results on stationary-point stability and convergence time, propose a large-batch weight-decay scaling rule, and demonstrate empirical improvements on CIFAR-10/100 with several architectures, while noting more modest gains on ImageNet and language modeling. The work offers a principled path to improve generalization and stability for adaptive gradient methods in standard and large-batch training regimes, with clear directions for future schedulers to optimize early versus late training behavior.

Abstract

Weight decay is a simple yet powerful regularization technique that has been very widely used in training deep neural networks (DNNs). While weight decay has attracted much attention, previous studies have overlooked pitfalls concerning the large gradient norms that weight decay causes. In this paper, we find that weight decay can lead to large gradient norms in the final phase of training (or at the terminated solution), which often indicates bad convergence and poor generalization. To mitigate these gradient-norm-centered pitfalls, we present the first practical scheduler for weight decay, called the Scheduled Weight Decay (SWD) method, which dynamically adjusts the weight decay strength according to the gradient norm and significantly penalizes large gradient norms during training. Our experiments also show that SWD indeed mitigates large gradient norms and often significantly outperforms the conventional constant weight decay strategy for Adaptive Moment Estimation (Adam).
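
As a rough illustration of what adjusting weight decay "according to the gradient norm" could look like inside Adam, the sketch below scales a decoupled decay term by the inverse square root of the mean second-moment estimate, so the decay weakens when gradients are large. The function name `adam_swd_step`, the choice of the mean second moment as a gradient-norm proxy, and all default hyperparameters are assumptions made for illustration; the abstract does not specify the exact rule.

```python
import math


def adam_swd_step(theta, grad, m, v, t, lr=1e-3, wd=5e-4,
                  beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam-style step with a gradient-norm-aware decay term (illustrative only).

    theta, grad, m, v are equal-length lists of floats; t is the 1-based step count.
    """
    # Standard Adam moment estimates with bias correction.
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grad)]
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]

    # Assumption: summarize gradient magnitude by the mean second moment,
    # and shrink the decay strength as that magnitude grows.
    v_bar = sum(v_hat) / len(v_hat)
    decay_scale = 1.0 / (math.sqrt(v_bar) + eps)

    new_theta = [
        th - lr * mh / (math.sqrt(vh) + eps) - lr * decay_scale * wd * th
        for th, mh, vh in zip(theta, m_hat, v_hat)
    ]
    return new_theta, m, v, decay_scale


# Compare the decay scale for small and large gradients at the first step.
zeros = [0.0, 0.0]
_, _, _, scale_small = adam_swd_step([1.0, -1.0], [2.0, 2.0], zeros, zeros, t=1)
_, _, _, scale_large = adam_swd_step([1.0, -1.0], [20.0, 20.0], zeros, zeros, t=1)
```

With gradients ten times larger, the returned `decay_scale` is roughly ten times smaller, which is the behavior this sketch assumes rather than a statement of the authors' exact scheduler.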

Paper Structure

This paper contains 19 sections, 4 theorems, 22 equations, 24 figures, 4 tables, 8 algorithms.

Key Result

Theorem 1

Suppose the learning dynamics are governed by GD with vanilla weight decay (Equation eq:originalwd) and the learning rate satisfies $\eta_{t} \in (0, +\infty)$. If there exists $\delta$ such that $0 < \delta \leq |\eta_{t} - \eta_{t+1}|$ for all $t>t_{0}$, then the learning dynamics cannot converge to any n
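
The intuition can be sketched heuristically in a few lines, assuming (as the $- \lambda^{\prime} \theta$ form in the Figure 1 caption suggests) that vanilla weight decay is the update in the first line below. Here $L$ denotes the training loss and $\theta^{\star}$ a candidate limit point; both are notational assumptions not defined in this summary.

```latex
\begin{align}
\theta_{t+1} &= (1 - \lambda^{\prime})\,\theta_{t} - \eta_{t} \nabla L(\theta_{t})
  && \text{(assumed form of vanilla weight decay)} \\
\eta_{t} \nabla L(\theta^{\star}) + \lambda^{\prime} \theta^{\star} &= 0
  && \text{(required at every } t > t_{0} \text{ for } \theta^{\star} \text{ to be a fixed point)} \\
(\eta_{t} - \eta_{t+1})\, \nabla L(\theta^{\star}) &= 0
  && \text{(subtracting the conditions at } t \text{ and } t+1\text{)}
\end{align}
```

Because $|\eta_{t} - \eta_{t+1}| \geq \delta > 0$, the last line forces $\nabla L(\theta^{\star}) = 0$, and the second line then forces $\lambda^{\prime} \theta^{\star} = 0$. Under these assumptions, a learning rate that keeps changing leaves very little room for a fixed point shared by all steps.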

Figures (24)

  • Figure 1: We compare Equation \ref{eq:originalwd}-based weight decay and Equation \ref{eq:dwd}-based weight decay by training ResNet18 on CIFAR-10 via vanilla SGD. In the presence of a popular learning rate scheduler, Equation \ref{eq:dwd}-based weight decay shows better test performance, which demonstrates that the form $- \eta_{t} \lambda \theta$ is a better weight decay implementation than $- \lambda^{\prime} \theta$.
  • Figure 2: We train ResNet18 via SGD on CIFAR-10 to verify that $t_{\mathrm{convergence}} = \mathcal{O}\left( \lambda^{-1} \right)$. With a fixed budget of 200 epochs, the optimal weight decay is about $0.0005$. With $0.1\lambda^{-1}$ epochs, decreasing weight decay monotonically decreases test performance. Table \ref{table:optimalwd} further supports that the optimal weight decay is approximately inversely proportional to the number of epochs.
  • Figure 3: Large-batch training ($B=16384$) with various learning rates and weight decay values. Note that $\eta=10^{-3}$ and $\lambda=10^{-4}$ is the baseline choice for $B=128$. Subfigures (a) and (b) show that even multiplying the learning rate by 16 harms optimization convergence. Subfigure (c) shows that Rule \ref{rule:lrlinearscaling} is completely invalid in this common large-batch training setting. Subfigure (d) shows that multiplying weight decay by 128 ($\lambda=0.0128$) gives the lowest test error, which fully supports the proposed Rule \ref{rule:wdlinearscaling}.
  • Figure 4: The learning curves of AdamS, AdamW, and Adam on CIFAR-10 and CIFAR-100. AdamS shows significantly better generalization than AdamW and Adam.
  • Figure 5: Scatter plot of training losses and test errors during the final 40 epochs of training ResNet34 on CIFAR-100. Even with similar or higher training losses, AdamS still generalizes better than the other Adam variants. The scatter plot for CIFAR-10 is in Appendix \ref{sec:appresult}.
  • ...and 19 more figures
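
The two weight decay forms contrasted in the Figure 1 caption, and the batch-size scaling reported in the Figure 3 caption, can be sketched as follows. This is a minimal illustration that isolates the decay term (no gradient step). The function names, the step schedule in `lrs`, and the assumption that the scaling rule is linear in the batch-size ratio are hypothetical choices made for this example, not details confirmed by the captions.

```python
def decay_fixed(theta, wd_prime):
    """Decay of the form -lambda' * theta: shrinkage ignores the learning rate."""
    return theta - wd_prime * theta


def decay_coupled(theta, lr_t, wd):
    """Decay of the form -eta_t * lambda * theta: shrinkage follows the learning rate."""
    return theta - lr_t * wd * theta


def scale_weight_decay(base_wd, base_batch, batch):
    """Assumed linear rule: scale weight decay by the batch-size ratio."""
    return base_wd * batch / base_batch


# Hypothetical step schedule: the learning rate drops by 10x twice.
lrs = [0.1, 0.1, 0.01, 0.001]

theta_fixed = 1.0
theta_coupled = 1.0
for lr_t in lrs:
    # Both forms shrink theta identically while lr_t = 0.1 (0.1 * 5e-4 = 5e-5),
    # but only the coupled form weakens once the learning rate is reduced.
    theta_fixed = decay_fixed(theta_fixed, wd_prime=5e-5)
    theta_coupled = decay_coupled(theta_coupled, lr_t, wd=5e-4)

# Baseline from the Figure 3 caption: lambda = 1e-4 at B = 128, scaled to B = 16384.
large_batch_wd = scale_weight_decay(1e-4, base_batch=128, batch=16384)
```

Under these assumptions `large_batch_wd` reproduces the $\lambda=0.0128$ value quoted in the Figure 3 caption, and `theta_coupled` ends slightly larger than `theta_fixed` because its decay weakens together with the learning rate.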

Theorems & Definitions (9)

  • Definition 1: The stability of the stationary point
  • Definition 2: Stable Weight Decay
  • Theorem 1: Non-convergence of GD with unstable weight decay
  • Remark
  • Theorem 2: Dynamics of weight decay
  • Corollary 1
  • Corollary 2
  • proof
  • proof