How to set AdamW's weight decay as you scale model and dataset size

Xi Wang, Laurence Aitchison

TL;DR

This work reframes AdamW as an exponential moving average (EMA) of recent weight updates, linking the EMA timescale to the product of learning rate and weight decay via $1/\tau_{\text{iter}} = \eta \lambda$ and defining $\tau_{\text{epoch}} = \tau_{\text{iter}}/M$. It demonstrates that the optimal EMA timescale is largely invariant to changes in model and dataset size, yielding practical weight-decay scaling rules: under a fixed learning rate, larger datasets favor smaller $\lambda$, while under μP LR scaling, larger models require larger $\lambda$. The paper validates these rules across ResNet-18, Vision Transformers, and NanoGPT pretraining, revealing that naïve μP scaling of the learning rate breaks for AdamW unless $\lambda$ is scaled to keep the EMA timescale constant. By aligning weight-decay scaling with the EMA view, the authors restore stable LR transfer and provide a principled approach to hyperparameter scaling in large-scale training scenarios.
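As a sketch of where the relation $1/\tau_{\text{iter}} = \eta \lambda$ comes from, assume the standard decoupled AdamW step and write $u_t$ for the Adam-normalized update direction (the symbol $u_t$ is introduced here for illustration and is not taken from the paper). Rearranging the step gives the usual EMA form, with the weights acting as a moving average of the scaled updates $-u_t/\lambda$:

$$w_t = (1 - \eta\lambda)\, w_{t-1} - \eta\, u_t = \left(1 - \frac{1}{\tau_{\text{iter}}}\right) w_{t-1} + \frac{1}{\tau_{\text{iter}}}\left(-\frac{u_t}{\lambda}\right), \qquad \frac{1}{\tau_{\text{iter}}} = \eta\lambda.$$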

Abstract

The scaling of the optimal AdamW weight decay hyperparameter with model and dataset size is critical as we seek to build larger models, but is poorly understood. We show that weights learned by AdamW can be understood as an exponential moving average (EMA) of recent updates. This gives critical insights for how to set the weight decay in AdamW, and how the weight decay should scale with model and dataset size. In particular, the key hyperparameter for an exponential moving average is the EMA timescale. Intuitively, the EMA timescale can be understood as the number of recent iterations the EMA averages over. We find that the optimal timescale, measured in epochs, is roughly constant as we change model and dataset size. Moreover, given a learning rate, there is a one-to-one mapping from the EMA timescale to the weight decay hyperparameter. Thus, if the optimal EMA timescale is constant, that implies that as the dataset size increases, the optimal weight decay should fall and as the model size increases, the optimal weight decay should increase (if we follow the muP recommendation for scaling the learning rate). We validate these scaling rules on ResNet-18 and Vision Transformers trained on CIFAR-10 and ImageNet, and on NanoGPT pre-training on OpenWebText. Finally, we found that as training progresses, muP's learning rate scaling breaks down for AdamW unless weight decay is scaled appropriately.
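The one-to-one mapping between timescale and weight decay can be sketched in a few lines of Python. This is a minimal illustration, assuming $M$ is the number of iterations per epoch; the function names, the target timescale of 1 epoch, and the two smaller dataset sizes are hypothetical choices made here, not values from the paper.

```python
def tau_iter(lr: float, weight_decay: float) -> float:
    """EMA timescale in iterations: 1 / tau_iter = lr * weight_decay."""
    return 1.0 / (lr * weight_decay)


def tau_epoch(lr: float, weight_decay: float, iters_per_epoch: int) -> float:
    """EMA timescale in epochs: tau_epoch = tau_iter / M."""
    return tau_iter(lr, weight_decay) / iters_per_epoch


def weight_decay_for_tau_epoch(
    target_tau_epoch: float, lr: float, iters_per_epoch: int
) -> float:
    """Invert the mapping: the weight decay that yields the target timescale."""
    return 1.0 / (lr * target_tau_epoch * iters_per_epoch)


if __name__ == "__main__":
    lr = 1e-3          # initial learning rate quoted in the Figure 1 caption
    batch_size = 100   # batch size quoted in the Figure 1 caption
    target = 1.0       # HYPOTHETICAL target timescale in epochs (illustration only)
    # 320,000 appears in the captions; the smaller sizes are hypothetical.
    for dataset_size in (80_000, 160_000, 320_000):
        m = dataset_size // batch_size  # iterations per epoch
        wd = weight_decay_for_tau_epoch(target, lr, m)
        print(f"N={dataset_size:>7}  M={m:>5}  weight_decay={wd:.4f}")
```

Holding the target timescale and learning rate fixed, the computed weight decay shrinks in proportion to the number of iterations per epoch, which is the "larger datasets favor smaller $\lambda$" rule stated above.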

Paper Structure (30 sections, 1 theorem, 45 equations, 19 figures)

Key Result

Theorem 1

Consider two AdamW optimizers with different learning rates, weight decays, initialization scales and epsilons ($\eta_t, \lambda, \sigma, \epsilon$ vs. $\eta_t', \lambda', \sigma', \epsilon'$). Take $w_t$ to be the parameters learned by the first optimizer after the $t$th optimization step, and $w_t'$ ... where $\xi$ is random noise (e.g. IID Gaussian). Consider a scale-invariant network, in the sense ...

Figures (19)

  • Figure 1: The optimal $\tau_\text{epoch}$ transfers across dataset sizes. We trained ResNet-18 (A) and ViT (B) on subsets of downsampled ImageNet of various sizes (lines of different colors) with different weight decays (dots on the lines) and a fixed batch size of 100. An initial learning rate of $10^{-3}$ is used with cosine decay scheduling. The performance metrics after 100 epochs are plotted against the weight decay $\lambda$ and the corresponding timescale $\tau_\text{epoch}$ computed with the initial learning rate. The dashed lines show the optimal $\tau_\text{epoch}$ at a subset size of $320{,}000$. In both models, the optimal $\tau_\text{epoch}$ is fairly stable across dataset sizes, whereas the optimal $\lambda$ decreases dramatically as dataset size grows.
  • Figure 2: For LLM pre-training, the optimal weight decay shifts with dataset size but $\tau_\text{epoch}$ transfers. We trained a 124M NanoGPT on subsets of OpenWebText of different sizes (line colors) for 4 epochs with a fixed batch size, an initial learning rate of $6 \times 10^{-4}$ and various $\lambda$ (dots on lines, varied in powers of 2). As dataset size increases, the weight decay that gives the optimal validation loss (red crosses) decreases, whereas $\tau_\text{epoch}$ ($= \tau_\text{iter}/M$) is stable across scales.
  • Figure 3: The optimal $\lambda$ increases with model size whereas the optimal timescale is more stable. We trained ResNet-18 on a subset of ImageNet 32x32, with varying width factor $s$ (lines of different colors) under a fixed base learning rate $10^{-3}$ with varying weight decay (dots on the lines) and plotted the metrics after 50 epochs vs. weight decay strength. The top row scales the hyperparameters using the direct µP approach (i.e. fixed $\lambda$), while the bottom row scales the hyperparameters to ensure $\tau_\text{iter}$ is fixed (i.e. $\lambda$ increases with model size). Note that as $\eta_\text{base} = 10^{-3}$ is fixed, there is a direct relationship between the optimal $\lambda_\text{base}$ and the optimal $\tau_\text{iter;base}$.
  • Figure 4: For LLM pre-training, the optimal weight decay increases with model size whereas the optimal timescales transfer. We trained 8-layer GPTs on OpenWebText, with various widths (line colors), under two $\eta_\text{base}$, with actual learning rates for each width scaled following µP with $s=\tfrac{1024}{\textrm{width}}$. We plot the final validation loss against various $\lambda_\text{base}$ (dots on lines, varied by powers of $2$). We align all lines to the rightmost point for better visibility, where the numbers in the brackets denote the actual optimal validation loss at each width. We considered training under two parameterizations: if we keep $\lambda$ decoupled from $s$, the optimal $\lambda_\text{base}$ shows a clear shift with model size (top rows); if $\lambda$ increases with $s$, the optimal $\lambda_\text{base}$ becomes much more stable across widths.
  • Figure 5: AdamW breaks the learning rate scaling of µP. Following the experimental setting of Yang et al. (2022), we trained a ResNet-18 with varying width factor $s$ (lines of different colors) under various base learning rates $\eta_\text{base}$ (x-axis) on CIFAR-10 (A) and a $320{,}000$-sample subset of ImageNet 32x32 (B). We then plotted the metrics after 200 (for CIFAR-10) and 50 (for ImageNet) epochs against $\eta_\text{base}$. The top row scales the hyperparameters using the direct µP approach (i.e. fixed $\lambda$), while the bottom row scales the hyperparameters to ensure $\tau_\text{iter}$ is fixed (i.e. $\lambda$ increases with model size). On both datasets, the direct approach breaks the stability of the optimal $\eta_\text{base}$ in terms of test metrics because it changes the timescale, whereas our scaling allows for a consistent $\eta_\text{base}$ across model sizes.
  • ...and 14 more figures
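
The two scaling schemes compared in the captions above (fixed $\lambda$ vs. fixed $\tau_\text{iter}$) can be sketched as follows. This is an illustrative simplification, not the paper's implementation: it assumes a hypothetical `width_mult` (model width divided by base width), assumes µP shrinks the learning rate as `lr_base / width_mult`, and ignores any per-layer distinctions. The function names and the base values are likewise hypothetical.

```python
import math


def mup_adamw_hparams(
    lr_base: float, wd_base: float, width_mult: float, keep_timescale: bool = True
) -> tuple[float, float]:
    """Return (lr, weight_decay) for a model `width_mult` times wider than the base.

    ASSUMPTION: the learning rate is scaled as lr_base / width_mult.
    keep_timescale=False: weight decay is left fixed (the "direct" approach).
    keep_timescale=True:  weight decay grows with width so lr * wd is unchanged.
    """
    lr = lr_base / width_mult
    wd = wd_base * width_mult if keep_timescale else wd_base
    return lr, wd


def tau_iter(lr: float, weight_decay: float) -> float:
    """EMA timescale in iterations: 1 / tau_iter = lr * weight_decay."""
    return 1.0 / (lr * weight_decay)


if __name__ == "__main__":
    lr_base, wd_base = 1e-3, 0.1  # HYPOTHETICAL base values for illustration
    for width_mult in (1, 2, 4):
        direct = mup_adamw_hparams(lr_base, wd_base, width_mult, keep_timescale=False)
        fixed = mup_adamw_hparams(lr_base, wd_base, width_mult, keep_timescale=True)
        print(
            f"width x{width_mult}: "
            f"direct tau_iter={tau_iter(*direct):.0f}, "
            f"fixed-timescale tau_iter={tau_iter(*fixed):.0f}"
        )
```

Under these assumptions the direct scheme lets $\tau_\text{iter}$ grow in proportion to width, while scaling $\lambda$ up with width holds it constant, which is the property the bottom rows of Figures 3 and 5 rely on.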

Theorems & Definitions (1)

  • Theorem 1