Correction of Decoupled Weight Decay

Jason Chuan-Chih Chou

TL;DR

This paper analyzes how decoupled weight decay shapes training dynamics. Assuming that updates become independent of the weights at steady state, it argues that decoupled weight decay should scale with γ^2 rather than γ. It finds that eliminating the contribution of the perpendicular component of the update to the weight norm changes the training dynamics very little, and derives that the Total Update Contribution (TUC) of a minibatch is governed by a momentum-dependent effective learning rate γ_eff. The proposed ScionC uses corrected weight decay and momentum-normalized updates; experiments on ViT-S/16 and Modded-NanoGPT show more stable weight and gradient norms and competitive accuracy. The work clarifies how to set weight decay in decoupled schemes and points to practical strategies for stable training across momentum regimes.

Abstract

Decoupled weight decay, solely responsible for the performance advantage of AdamW over Adam, has long been set proportional to the learning rate $\gamma$ without question. Some researchers have recently challenged this assumption and argued, based on orthogonality arguments at steady state, that decoupled weight decay should instead be set $\propto \gamma^2$. To the contrary, we find that eliminating the contribution of the perpendicular component of the update to the weight norm leads to little change in the training dynamics. Instead, we derive that decoupled weight decay $\propto \gamma^2$ results in a stable weight norm based on the simple assumption that updates become independent of the weights at steady state, regardless of the nature of the optimizer. Based on the same assumption, we derive and empirically verify that the Total Update Contribution (TUC) of a minibatch under the Scion optimizer is better characterized by the momentum-dependent effective learning rate, whose optimal value transfers. We also show that decoupled weight decay $\propto \gamma^2$ leads to stable weight and gradient norms, allows us to better control the training dynamics, and improves model performance.
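The steady-state argument in the abstract can be illustrated numerically. The sketch below is not the paper's code: it assumes a toy update rule `w <- (1 - eta) * w - gamma * u`, where `eta` is the per-step decay factor and `u` is a unit-norm random direction drawn independently of `w` (a stand-in for the independence assumption). Under these assumptions the expected squared norm obeys $E\|w_{t+1}\|^2 = (1-\eta)^2 E\|w_t\|^2 + \gamma^2$, so its steady-state value is roughly $\gamma^2 / (2\eta)$. The function name `steady_state_sq_norm` and the constants `c` and `lam` are hypothetical and chosen only for illustration.

```python
import numpy as np

def steady_state_sq_norm(gamma, eta, dim=128, steps=20000, seed=0):
    """Toy model (an assumption, not the paper's setup):
    w <- (1 - eta) * w - gamma * u, with u a unit-norm random vector
    independent of w. Returns the mean ||w||^2 over the second half."""
    rng = np.random.default_rng(seed)
    w = np.zeros(dim)
    sq_norms = []
    for t in range(steps):
        u = rng.standard_normal(dim)
        u /= np.linalg.norm(u)           # unit-norm update, independent of w
        w = (1.0 - eta) * w - gamma * u  # decoupled decay, then the update
        if t >= steps // 2:              # discard the warm-up phase
            sq_norms.append(float(w @ w))
    return float(np.mean(sq_norms))

if __name__ == "__main__":
    c, lam = 50.0, 0.5  # hypothetical constants for illustration
    for gamma in (0.01, 0.02):
        corrected = steady_state_sq_norm(gamma, eta=c * gamma**2)
        conventional = steady_state_sq_norm(gamma, eta=lam * gamma)
        print(f"gamma={gamma}: eta~gamma^2 -> {corrected:.4f}, "
              f"eta~gamma -> {conventional:.4f}")
```

In this toy model, a decay factor proportional to $\gamma^2$ keeps the steady-state squared norm near $1/(2c)$ for both learning rates, while a decay factor proportional to $\gamma$ gives a squared norm that grows roughly linearly with $\gamma$, which matches the qualitative picture of the weight norm shrinking as the learning rate decays.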

Paper Structure

This paper contains 18 sections, 17 equations, 13 figures, 5 tables, 2 algorithms.

Figures (13)

  • Figure 1: Training a ViT-S/16 with "Renormalized" AdamW results in negligible differences in top-1 val. accuracy (77.15 vs. 77.45 for the $\gamma = 0.001$, $\lambda = 0.1$ AdamW baseline), weight norm, and gradient norm throughout the training process. Note the suppression of the weight norm and the surge of the gradient norm towards the end of the cosine learning rate decay, characteristic of AdamW. Except for the use of the PyTorch Inception crop with crop scale lower bound $a_{min} = 0.2$, the setup is identical to that of Beyer et al. (2022).
  • Figure 2: ImageNet-1k top-1 val. accuracy of simple ViT-S/16 trained for 90 epochs with momentum $\alpha \in [0.01, 0.5]$, plotted against the maximum learning rate $\gamma$ (left) vs. the maximum steady-state effective learning rate $\gamma_\mathrm{eff}$ (right) for the non-Sign parameters at the start of cosine decay. The optimal learning rate $\gamma$ increases with momentum $\alpha$, while the optimal effective learning rate $\gamma_\mathrm{eff}$ stays within a factor of 2 across the momentum values and well within the granularity of the sweep. Weight and gradient norms are kept stable and comparable with ScionC (Algorithm \ref{alg:ScionC} with maximum learning rate $\gamma_L = 0.2$, momentum $\alpha = 0.1$, weight decay coefficient $\lambda_L = 0.004$ for the Sign layer and $C^2_l = 1.1875$ for other parameters) for these experiments.
  • Figure 3: Simple ViT-S/16 trained on ImageNet-1k for 90 epochs with ScionC (Algorithm \ref{alg:ScionC} with maximum learning rate $\gamma_L = 0.2$, momentum $\alpha = 0.1$, weight decay coefficient $\lambda_L = 0.004$ for the Sign layer and maximum learning rate $\gamma = 0.01$, $C^2_l = 1.1875$ for other parameters) and baseline cosine learning rate decay vs. the equivalent momentum scheduling. For the momentum scheduling experiments, $\alpha$ increases from $0.1$ to $\alpha_\mathrm{max} \in \{0.2, 0.5, 1.0\}$ such that the effective learning rate $\gamma_\mathrm{eff}$ matches that of the cosine learning rate baseline until $\alpha_\mathrm{max}$ is reached. The models converge to the same top-1 val. accuracy up to $\alpha_\mathrm{max} = 0.5$, where the weight norm approximation starts to break down.
  • Figure 4: Training 124M Modded-NanoGPT on FineWeb-Edu-100B, Scion vs. ScionC. $\lambda \propto \gamma$ scaling of ScionC results in more stable weight norm, gradient norm, and spectral norms. The final validation loss is 2.846 for Scion and 2.838 for ScionC.
  • Figure 5: Training ViT-S/16 on ImageNet-1k, AdamW (upper) vs. AdamC (lower). $\lambda \propto \gamma$ scaling of AdamC results in more stable weight and gradient norms. Note that the model does not seem to be in steady state even after 300 epochs.
  • ...and 8 more figures