Straight to Zero: Why Linearly Decaying the Learning Rate to Zero Works Best for LLMs

Shane Bergsma, Nolan Dey, Gurpreet Gosal, Gavia Gray, Daria Soboleva, Joel Hestness

TL;DR

This work addresses the question of which LR schedules yield minimum loss in LLM pre-training when training is performed at compute-optimal data sizes. It develops an extended EMA view of AdamW and a bias–variance training model to explain why linear decay-to-zero (D2Z) often outperforms traditional schedules like 10x cosine, especially as tokens-per-parameter increase. Through large-scale experiments across model scales (111M–1.7B), datasets, and batch configurations, the study shows D2Z provides consistent improvements in training, validation, and downstream tasks, with substantial compute savings in compute-bound regimes. The results imply that D2Z not only reduces gradient noise effectively but also preserves early bias reduction, offering robust hyperparameter transfer under μP and suggesting broader implications for high-TPP training efficiency and schedule design.

Abstract

LLMs are commonly trained with a learning rate (LR) warmup, followed by cosine decay to 10% of the maximum (10x decay). In a large-scale empirical study, we show that under an optimal peak LR, a simple linear decay-to-zero (D2Z) schedule consistently outperforms other schedules when training at compute-optimal dataset sizes. D2Z is superior across a range of model sizes, batch sizes, datasets, and vocabularies. Benefits increase as dataset size increases. Leveraging a novel interpretation of AdamW as an exponential moving average of weight updates, we show how linear D2Z optimally balances the demands of early training (moving away from initial conditions) and late training (averaging over more updates in order to mitigate gradient noise). In experiments, a 610M-parameter model trained for 80 tokens-per-parameter (TPP) using D2Z achieves lower loss than when trained for 200 TPP using 10x decay, corresponding to an astonishing 60% compute savings. Models such as Llama2-7B, trained for 286 TPP with 10x decay, could likely have saved a majority of compute by training with D2Z.
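The two schedules contrasted in the abstract, and the arithmetic behind the 60% figure, can be sketched as follows. This is a minimal illustration, not the authors' code: the function names, the linear warmup shape, and the step-indexing convention are assumptions made here, and the savings calculation assumes training compute scales proportionally with tokens at a fixed model size.

```python
import math

def lr_at_step(step, total_steps, peak_lr, warmup_steps,
               schedule="linear", final_frac=0.0):
    """Learning rate at `step` for linear warmup followed by decay.

    final_frac=0.0 gives decay-to-zero (D2Z); final_frac=0.1 gives the
    "10x" decay that ends at 10% of the peak. The names and the warmup
    shape are illustrative assumptions, not taken from the paper.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup that reaches peak_lr on the last warmup step.
        return peak_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    # progress is 0 right after warmup and 1 at the end of training.
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    floor = final_frac * peak_lr
    if schedule == "linear":
        return floor + (peak_lr - floor) * (1.0 - progress)
    if schedule == "cosine":
        return floor + (peak_lr - floor) * 0.5 * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown schedule: {schedule}")

def compute_savings(tpp_new, tpp_baseline):
    """Fraction of compute saved at fixed model size, assuming compute
    is proportional to the number of training tokens."""
    return 1.0 - tpp_new / tpp_baseline

# Linear-D2Z vs. Cosine-10x at the final step of a 1000-step run.
d2z_final = lr_at_step(1000, 1000, 1e-3, 100, "linear", final_frac=0.0)
cos10_final = lr_at_step(1000, 1000, 1e-3, 100, "cosine", final_frac=0.1)

# 80 TPP instead of 200 TPP for the same 610M model.
savings = compute_savings(80, 200)
```

Under the proportionality assumption, training at 80 TPP instead of 200 TPP uses 80/200 = 40% of the tokens, which is where the 60% savings quoted in the abstract comes from.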

Paper Structure

This paper contains 48 sections, 11 equations, 35 figures, 6 tables.

Figures (35)

  • Figure 1: A 610M model trained for 80 TPP with $\hbox{Linear}$-$\hbox{D2Z}$ has better train (and validation) loss than when trained for 200 TPP with $\hbox{Linear}$-$\hbox{10}\times$.
  • Figure 2: LR schedules and their update-combination duals: Each LR schedule, $\eta_t$ (left) and weight decay, $\lambda$, implies a weighted combination of weight updates, with combination coefficients $c_{t,i}$ (right, log-scale) giving the contribution of $i$th update to parameters $\theta_t$ at step $t$ (111M scale, coefficients at final step). The more sudden the drop in LR, the less emphasis on valuable later updates, perhaps explaining why $\hbox{Step}$ underperforms $\hbox{Cosine}$ and $\hbox{Cosine}$ underperforms $\hbox{Linear}$ decay.
  • Figure 3: Bias & variance in LLM pre-training: as training duration increases (higher TPP), the importance of variance reduction (and of having a lower LR) increases. With no decay ($\hbox{Constant}$, \ref{fig:const_low}, \ref{fig:const_high}), the optimal peak LR must drop significantly lower, neglecting bias reduction. With $\hbox{D2Z}$ (\ref{fig:dtoz_low}, \ref{fig:dtoz_high}), periods of bias and variance reduction can both be enjoyed without large shifts in peak LR.
  • Figure 4: HP influence vs. TPP: Higher TPP means higher gradient noise; LR decay & weight decay settings thus increase in importance with TPP.
  • Figure 5: Comparing decay schedules across TPP (111M scale): As TPP increases, $\hbox{Linear}$-$\hbox{D2Z}$ outperforms $\hbox{10}\times$, especially at the proxy-tuned peak LR (red lines).
  • ...and 30 more figures
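
The update-combination view in the Figure 2 caption can be illustrated with a short sketch. It assumes the standard decoupled weight decay form of the AdamW step, $\theta_t = (1 - \eta_t \lambda)\,\theta_{t-1} - \eta_t u_t$, where $u_t$ is the Adam-normalized update. Writing $\alpha_t = \eta_t \lambda$ and $x_t = -u_t / \lambda$ turns this into an exponential moving average, $\theta_t = (1 - \alpha_t)\,\theta_{t-1} + \alpha_t x_t$, and unrolling gives $c_{t,i} = \alpha_i \prod_{j=i+1}^{t} (1 - \alpha_j)$. The derivation and the function name below are a reconstruction for illustration, not copied from the paper.

```python
def ema_coefficients(lrs, weight_decay):
    """Coefficients c_{t,i} at the final step t = len(lrs).

    lrs[i] is the learning rate used at update i. Returns (coeffs, init_weight),
    where coeffs[i] is the weight of update i in the final parameters and
    init_weight is the remaining weight on the initial parameters.
    Assumes 0 <= lr * weight_decay < 1 for every step.
    """
    alphas = [lr * weight_decay for lr in lrs]
    coeffs = [0.0] * len(alphas)
    tail = 1.0  # running product of (1 - alpha_j) over all j > i
    for i in range(len(alphas) - 1, -1, -1):
        coeffs[i] = alphas[i] * tail
        tail *= 1.0 - alphas[i]
    # After the loop, tail is the product over all steps: the weight on theta_0.
    return coeffs, tail

# Compare a constant LR with a linear decay that reaches zero on the last step.
steps, peak_lr, wd = 50, 1e-3, 0.1
const_coeffs, const_init = ema_coefficients([peak_lr] * steps, wd)
d2z_lrs = [peak_lr * (1.0 - i / (steps - 1)) for i in range(steps)]
d2z_coeffs, d2z_init = ema_coefficients(d2z_lrs, wd)
```

The coefficients and the weight on the initial parameters always sum to one, so a schedule only redistributes emphasis among updates. With a constant LR the emphasis grows toward the most recent updates, whereas decaying the LR spreads it over more of them.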