Why Do We Need Weight Decay in Modern Deep Learning?

Francesco D'Angelo, Maksym Andriushchenko, Aditya Varre, Nicolas Flammarion

TL;DR

This work investigates why weight decay helps in modern deep learning, arguing that its effect is largely on the training dynamics rather than that of a classical regularizer. Through theoretical arguments and extensive experiments on ResNets for vision and transformers for language, the authors show that, combined with large learning rates, weight decay sustains non-vanishing SGD noise that implicitly regularizes training by controlling the Jacobian norm, while in one-pass training it modulates the effective learning rate to improve optimization and stability. A unifying view across regimes ties weight decay either to loss stabilization or to adjusting the bias-variance trade-off, rather than to simply constraining model capacity. The findings offer practical guidance for tuning weight decay (WD) jointly with the learning rate (LR) and the exponential moving average (EMA) of the iterates, and shed light on why WD enables stable training in low-precision settings such as bfloat16.
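The effective-learning-rate mechanism can be made concrete with a standard toy calculation. The sketch below is not the paper's derivation. It assumes scale-invariant weights (for example, weights followed by a normalization layer), so that the gradient $g_t$ is orthogonal to $\mathbf{w}_t$ and its norm scales as $G / \|\mathbf{w}_t\|$ for an assumed constant $G$. Under plain SGD with weight decay, $\mathbf{w}_{t+1} = (1 - \eta\lambda)\mathbf{w}_t - \eta g_t$, the squared norm $N_t = \|\mathbf{w}_t\|^2$ then follows $N_{t+1} = (1-\eta\lambda)^2 N_t + \eta^2 G^2 / N_t$. Its fixed point gives an effective learning rate $\eta / \|\mathbf{w}\|^2 = \sqrt{\eta\lambda(2 - \eta\lambda)} / G$, which depends on $\eta$ and $\lambda$ only through their product. All function names are hypothetical.

```python
import math


def equilibrium_sq_norm(lr: float, wd: float, grad_scale: float) -> float:
    """Fixed point N* of N' = (1 - lr*wd)**2 * N + lr**2 * grad_scale**2 / N."""
    return grad_scale * math.sqrt(lr / (wd * (2.0 - lr * wd)))


def simulate_sq_norm(lr: float, wd: float, grad_scale: float,
                     steps: int = 20_000, sq_norm: float = 1.0) -> float:
    """Iterate the squared-norm recursion of SGD + weight decay.

    Assumption (toy model): weights are scale-invariant, so the gradient is
    orthogonal to w and ||g_t|| = grad_scale / ||w_t||.
    """
    for _ in range(steps):
        grad_sq = grad_scale ** 2 / sq_norm  # ||g_t||^2 shrinks as ||w_t|| grows
        # Weight decay shrinks the norm; the orthogonal gradient step grows it.
        sq_norm = (1.0 - lr * wd) ** 2 * sq_norm + lr ** 2 * grad_sq
    return sq_norm


def effective_lr(lr: float, sq_norm: float) -> float:
    """Effective step size for scale-invariant weights: eta / ||w||^2."""
    return lr / sq_norm


if __name__ == "__main__":
    # Two (lr, wd) pairs with the same product lr * wd = 1e-3.
    for lr, wd in [(0.1, 0.01), (0.05, 0.02)]:
        n_sim = simulate_sq_norm(lr, wd, grad_scale=1.0)
        n_eq = equilibrium_sq_norm(lr, wd, grad_scale=1.0)
        print(f"lr={lr}, wd={wd}: ||w||^2={n_sim:.4f} "
              f"(closed form {n_eq:.4f}), eta_eff={effective_lr(lr, n_sim):.5f}")
```

Both pairs share $\eta\lambda = 10^{-3}$ and should report the same $\eta_{\text{eff}} \approx 0.0447$ even though their weight norms differ by a factor of about two. This offers one heuristic reason why Figure 4 looks at $\eta \times \lambda$ rather than at $\eta$ and $\lambda$ separately.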

Abstract

Weight decay is a broadly used technique for training state-of-the-art deep networks, from image classification to large language models. Despite its widespread use and extensive study in the classical literature, its role in deep learning remains poorly understood. In this work, we highlight that the role of weight decay in modern deep learning is different from its regularization effect studied in classical learning theory. For deep networks on vision tasks trained with multipass SGD, we show how weight decay modifies the optimization dynamics, enhancing the ever-present implicit regularization of SGD via the loss stabilization mechanism. In contrast, for large language models trained for nearly one epoch, we describe how weight decay balances the bias-variance tradeoff in stochastic optimization, leading to lower training loss and improved training stability. Overall, we present a unifying perspective from ResNets on vision tasks to LLMs: weight decay is never useful as an explicit regularizer but instead changes the training dynamics in a desirable way. The code is available at https://github.com/tml-epfl/why-weight-decay
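The bias-variance tradeoff in stochastic optimization can be illustrated on a toy problem. This is a generic textbook sketch, not the paper's LLM analysis: it assumes a one-dimensional quadratic $f(x) = x^2/2$, zero-mean unit-variance gradient noise scaled by $\sigma$, and a fixed budget of $T$ SGD steps with constant step size $\eta$. Then $x_{t+1} = (1-\eta)x_t - \eta\sigma\xi_t$ and
$\mathbb{E}[x_T^2] = (1-\eta)^{2T} x_0^2 + \eta^2\sigma^2 \frac{1-(1-\eta)^{2T}}{1-(1-\eta)^2}$.
A small step size leaves a large bias (slow forgetting of $x_0$), while a large one leaves a high noise floor (variance). If, as the TL;DR states, weight decay in one-pass training acts by modulating the effective learning rate, it can be read as moving along this tradeoff. The function name is hypothetical.

```python
def expected_sq_distance(lr: float, steps: int, x0: float, noise_std: float) -> float:
    """Closed-form E[x_T^2] for SGD on f(x) = x^2 / 2 with additive gradient noise.

    Toy assumptions: x_{t+1} = x_t - lr * (x_t + noise_std * xi_t), with xi_t
    i.i.d., zero mean, unit variance; requires 0 < lr < 2.
    """
    rho = (1.0 - lr) ** 2                      # per-step contraction of E[x^2]
    bias = rho ** steps * x0 ** 2              # leftover memory of the initialization
    variance = lr ** 2 * noise_std ** 2 * (1.0 - rho ** steps) / (1.0 - rho)  # accumulated noise
    return bias + variance


if __name__ == "__main__":
    grid = [0.001, 0.01, 0.05, 0.1, 0.5]
    for lr in grid:
        value = expected_sq_distance(lr, steps=100, x0=10.0, noise_std=1.0)
        print(f"lr={lr:<6} E[x_T^2]={value:.4f}")
    best = min(grid, key=lambda lr: expected_sq_distance(lr, 100, 10.0, 1.0))
    print("best lr on this grid:", best)
```

With a budget of 100 steps, the smallest step sizes are dominated by the bias term and the largest by the variance term, so the best value on this grid lies in between.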

Paper Structure (19 sections, 1 theorem, 20 equations, 26 figures, 3 tables)

Key Result

Proposition 2

Assume $\|\mathbf{w}\| \in [a, b]$ and that, for any $x \in \mathcal{D}$, $\|\nabla h(\mathbf{w}, x)\| \in [m, M]$. Then, for $n$ sufficiently large, there exist constants $c_1, c_2$ such that

Figures (26)

  • Figure 1: Test error vs. dataset size on CIFAR-10-5m for a fixed number of training iterations. Weight decay helps in both the over-training regime and the under-training (one-pass) regime.
  • Figure 2: Training with and without weight decay. We report the test error for ResNet18 on CIFAR-10 and Tiny-ImageNet trained with and without weight decay and with small and large learning rates. We also include the corresponding EMA, shown as dashed lines. After the first 250 epochs, the learning rate is decayed to $\eta = 10^{-3}$ for all curves. We also report the L2 norm of the parameters and the training cross-entropy (CE), which, after the decay, converges to the same value for all runs with the same $\lambda$.
  • Figure 3: ResNet18 on Tiny-ImageNet. Heatmaps of the test error and of the Jacobian norm of the EMA for all combinations of $\eta$ and $\lambda$.
  • Figure 4: ResNet18 on Tiny-ImageNet, trained for $200$ epochs with different $\eta$ and $\lambda$. The scale of the noise increases monotonically with the training loss and with $\eta \times \lambda$. The test error, instead, has an optimal value of $\eta \times \lambda$, while the Jacobian norm decreases monotonically.
  • Figure 5: EMA vs. fine-tuning. Standard ResNet18 trained on CIFAR-10 for 100 epochs with fixed $\lambda = 0.0125$ and varying learning rates. The panels report the different levels of loss stabilization, the test errors, and the norms of the Jacobian and of the weights, respectively. Each quantity is measured for the SGD iterates, their EMA, and the fine-tuned model. Fine-tuning with $\eta = 10^{-3}$ is launched every 3 epochs and run for 100 epochs.
  • ...and 21 more figures
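Figures 2 and 5 compare the raw SGD iterates with their EMA. The following toy sketch shows why an EMA can reveal the point that noisy large-learning-rate iterates oscillate around. It is not the paper's setup: it assumes a one-dimensional quadratic $f(x) = x^2/2$ with additive Gaussian gradient noise and an EMA decay of $0.99$, both chosen for illustration, and the function name is hypothetical. Fine-tuning with a small learning rate, as in Figure 5, is another way to remove this noise.

```python
import random


def sgd_with_ema(lr: float, ema_decay: float, steps: int, noise_std: float = 1.0,
                 x0: float = 5.0, burn_in: int = 1_000, seed: int = 0):
    """Run noisy SGD on f(x) = x^2 / 2 and track an EMA of the iterates.

    Returns the mean squared distance to the minimizer (x = 0) of the raw
    iterates and of their EMA, both averaged after a burn-in period.
    """
    rng = random.Random(seed)  # seeded for reproducibility
    x = ema = x0
    sq_iter = sq_ema = 0.0
    count = 0
    for t in range(steps):
        grad = x + noise_std * rng.gauss(0.0, 1.0)      # stochastic gradient of x^2 / 2
        x -= lr * grad                                   # plain SGD step
        ema = ema_decay * ema + (1.0 - ema_decay) * x    # exponential moving average
        if t >= burn_in:
            sq_iter += x * x
            sq_ema += ema * ema
            count += 1
    return sq_iter / count, sq_ema / count


if __name__ == "__main__":
    mse_iter, mse_ema = sgd_with_ema(lr=0.5, ema_decay=0.99, steps=50_000)
    print(f"iterates: {mse_iter:.4f}  EMA: {mse_ema:.4f}")
```

With a large step size, the iterates keep a non-vanishing spread around the minimizer (about $\eta\sigma^2/(2-\eta) \approx 0.33$ here), whereas the EMA averages most of that noise away, ending up well over an order of magnitude closer.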

Theorems & Definitions (4)

  • Conjecture 1
  • Proposition 2
  • Proof
  • Conjecture 3