Why Gradients Rapidly Increase Near the End of Training
Aaron Defazio
TL;DR
The paper investigates why gradient norms spike near the end of long LLM training runs, attributing the blow-up to an unintended interaction between weight decay, normalization layers, and learning-rate schedules. It develops a theory based on steady-state gradient-to-weight ratios for normalized layers and shows that weight decay naturally balances these ratios across layers, which is what allows a single global learning rate to work. The paper proposes a simple, theory-motivated correction to weight decay that decouples the steady-state target from the learning-rate schedule; the resulting optimizer variants, AdamC and SGDC, improve stability and reduce loss. Experiments on ImageNet and large-language-model pretraining support the theory, showing reduced gradient blow-up and lower losses with the corrected decoupled weight decay.
Abstract
During long-duration Large Language Model (LLM) training runs, the gradient norm increases rapidly near the end of training. In this short note, we show that this increase is due to an unintended interaction between weight decay, normalization layers, and the learning rate schedule. We propose a simple correction that fixes this behavior while also resulting in lower loss values throughout training.
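The interaction described above can be illustrated with a small numerical sketch. It assumes the standard steady-state approximation for a normalized (scale-invariant) layer trained with SGD and weight decay, where the gradient-to-weight norm ratio settles near sqrt(2 * wd / lr). It also assumes, since this summary does not spell out the exact form of the correction, that the fix amounts to scaling weight decay by lr / lr_max so that the ratio wd / lr stays constant. The function names `cosine_lr`, `steady_state_ratio`, and `corrected_wd` are hypothetical and chosen for illustration only.

```python
import math


def cosine_lr(step, total_steps, lr_max):
    """Cosine learning-rate schedule decaying from lr_max toward 0."""
    return 0.5 * lr_max * (1.0 + math.cos(math.pi * step / total_steps))


def steady_state_ratio(lr, wd):
    """Approximate steady-state ||g|| / ||w|| for a normalized layer.

    Assumes SGD with weight decay and gradients orthogonal to the weights,
    which gives ||g|| / ||w|| ~= sqrt(2 * wd / lr).
    """
    return math.sqrt(2.0 * wd / lr)


def corrected_wd(wd, lr, lr_max):
    """Assumed form of the correction: shrink weight decay with the schedule."""
    return wd * lr / lr_max


lr_max, wd, total_steps = 0.1, 1e-4, 1000
for step in (0, 500, 900, 990):
    lr = cosine_lr(step, total_steps, lr_max)
    plain = steady_state_ratio(lr, wd)
    fixed = steady_state_ratio(lr, corrected_wd(wd, lr, lr_max))
    print(f"step={step:4d} lr={lr:.2e} plain={plain:.4f} corrected={fixed:.4f}")
```

Under these assumptions, the uncorrected ratio grows without bound as the learning rate decays toward zero, which is consistent with the late-training gradient blow-up, while the corrected ratio stays fixed at its lr_max value for the whole schedule.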
