Why Gradients Rapidly Increase Near the End of Training

Aaron Defazio

TL;DR

The paper investigates why gradient norms spike near the end of long-duration LLM training runs, attributing the blow-up to an unintended interaction between weight decay, normalization layers, and the learning-rate schedule. It develops a theory based on steady-state gradient-to-weight ratios for normalized layers and shows that weight decay naturally drives these ratios to the same value across layers, which is what allows a single global learning rate to work for all of them. It then proposes a simple, theory-motivated correction to weight decay that decouples the steady-state target from the learning-rate schedule, yielding the variants AdamC and SGDC, which improve stability and reduce loss. Experiments on ImageNet and LLM pretraining validate the theory, showing reduced gradient blow-up and lower loss with the corrected decoupled weight decay.
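
A back-of-the-envelope version of the steady-state argument, offered as a sketch under stated assumptions rather than the paper's own derivation: take plain SGD (no momentum) with learning rate $\gamma$ and decoupled weight decay $\lambda$ acting on the weights $x_t$ of a layer followed by normalization. Such a layer is scale-invariant, so its gradient $g_t$ is orthogonal to $x_t$ and the cross term vanishes when the update is squared.

```latex
x_{t+1} = (1 - \gamma\lambda)\,x_t - \gamma g_t,
\qquad \langle g_t, x_t \rangle = 0

\|x_{t+1}\|^2 = (1 - \gamma\lambda)^2 \|x_t\|^2 + \gamma^2 \|g_t\|^2

\text{At equilibrium } \|x_{t+1}\| = \|x_t\| = \|x\|:\quad
(2\gamma\lambda - \gamma^2\lambda^2)\,\|x\|^2 = \gamma^2 \|g\|^2
\;\Longrightarrow\;
\frac{\|g\|}{\|x\|} = \sqrt{\frac{2\lambda}{\gamma} - \lambda^2}
\approx \sqrt{\frac{2\lambda}{\gamma}} \quad (\gamma\lambda \ll 1)
```

The equilibrium ratio depends only on $\lambda$ and $\gamma$, not on the layer, which is the sense in which weight decay equalizes the ratios. It also grows as $\gamma \to 0$, so a decaying schedule keeps raising the equilibrium target for the ratio near the end of training. Under an assumed correction $\lambda_t = \lambda\gamma_t/\gamma_{\max}$ (our reading of the proposed fix, not quoted from the paper), the target becomes approximately $\sqrt{2\lambda/\gamma_{\max}}$, independent of $t$.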

Abstract

During long-duration Large Language Model (LLM) training runs, the gradient norm increases rapidly near the end of training. In this short note, we show that this increase is due to an unintended interaction between weight decay, normalization layers, and the learning rate schedule. We propose a simple correction that fixes this behavior while also resulting in lower loss values throughout training.
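
A minimal sketch of what such a correction could look like inside an AdamW-style update, written on plain Python lists so it runs without dependencies. The function name `adamc_step` and its parameters are hypothetical, and the correction is assumed to consist of multiplying the decoupled weight decay by `lr / lr_max` (the current learning rate over the schedule's peak); check this against the paper's algorithm before relying on it. The theory concerns weights of layers followed by normalization, so which parameters receive the corrected decay is left to the caller.

```python
import math


def adamc_step(x, grad, m, v, step, lr, lr_max, weight_decay,
               beta1=0.9, beta2=0.999, eps=1e-8, corrected=True):
    """One AdamW-style update on plain Python lists (hypothetical helper).

    `step` counts from 1 because it is used for bias correction.
    With corrected=True the decoupled weight decay is multiplied by
    lr / lr_max, an ASSUMED form of the AdamC correction; with
    corrected=False this is ordinary decoupled weight decay.
    Returns the new (x, m, v).
    """
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, grad)]
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    wd = weight_decay * lr / lr_max if corrected else weight_decay
    x = [
        xi - lr * (mi / bias1) / (math.sqrt(vi / bias2) + eps)  # Adam step
        - lr * wd * xi  # decoupled weight decay, also scaled by lr
        for xi, mi, vi in zip(x, m, v)
    ]
    return x, m, v


# Usage: three steps under a linearly decaying learning rate.
x, m, v = [1.0, -0.5], [0.0, 0.0], [0.0, 0.0]
lr_max = 1e-3
for step in range(1, 4):
    lr = lr_max * (1 - (step - 1) / 3)
    grad = [0.1 * xi for xi in x]  # toy gradient, for illustration only
    x, m, v = adamc_step(x, grad, m, v, step, lr, lr_max, weight_decay=0.1)
print(x)
```

While `lr == lr_max` the two variants coincide. Once the schedule decays, the per-step shrink factor of the corrected version, `lr * weight_decay * lr / lr_max`, falls quadratically in the learning rate instead of linearly.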

Paper Structure

This paper contains 10 sections, 23 equations, 5 figures, and 1 algorithm.

Figures (5)

  • Figure 1: A 120M-parameter LLM training run on FineWeb-Edu, in which the gradient norm more than doubles towards the end of training.
  • Figure 2: Gradient-to-weight ratios converge towards a steady equilibrium when training without a learning rate schedule. A 100-epoch ImageNet training run using a ResNet-50 model is shown, with each line indicating the ratio for a separate normalized layer. SGD without momentum was used, with learning rate $0.1$ and weight decay $0.0001$.
  • Figure 3: When a cosine learning rate schedule is used, the gradient-to-weight ratios are affected by the schedule. A 100-epoch ImageNet training run using a ResNet-50 model is shown, with each line indicating the ratio for a separate normalized layer. SGD with momentum $0.9$, learning rate $0.1$, and weight decay $0.0001$ was used.
  • Figure 4: A 200B-token training run of a 120M-parameter language model with the Llama 3 architecture on FineWeb-Edu.
  • Figure 5: ImageNet ResNet-50 training.
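
The mechanism behind Figures 2 and 3 can be illustrated with a toy simulation. This is a sketch under explicit assumptions, not the paper's experiment: the hypothetical `toy_gradient` stands in for a normalized layer by returning a random direction orthogonal to the weights with norm exactly `1 / ||x||`, the optimizer is SGD without momentum, and the weight decay (0.05) is far larger than the $0.0001$ used in the figures so that equilibrium is reached within a few thousand steps. The corrected run multiplies the weight decay by `lr / lr_max`, which is our assumed form of the paper's correction.

```python
import math
import random


def cosine_lr(step, total_steps, lr_max):
    """Cosine decay from lr_max towards 0, without warmup."""
    return 0.5 * lr_max * (1.0 + math.cos(math.pi * step / total_steps))


def norm(v):
    return math.sqrt(sum(a * a for a in v))


def toy_gradient(x, rng):
    """Hypothetical stand-in for the gradient of a normalized layer.

    Scale invariance makes the true gradient orthogonal to x with norm
    proportional to 1 / ||x||; here the direction is random and the
    norm is exactly 1 / ||x||.
    """
    z = [rng.gauss(0.0, 1.0) for _ in x]
    coef = sum(a * b for a, b in zip(z, x)) / sum(b * b for b in x)
    g = [a - coef * b for a, b in zip(z, x)]  # drop the component along x
    scale = 1.0 / (norm(g) * norm(x))
    return [scale * a for a in g]


def run(corrected, total_steps=5000, lr_max=0.1, weight_decay=0.05, dim=16, seed=0):
    """Plain SGD with decoupled weight decay; returns ||g|| / ||x|| per step."""
    rng = random.Random(seed)
    x = [1.0] + [0.0] * (dim - 1)  # ||x|| = 1 is close to the constant-LR equilibrium here
    ratios = []
    for step in range(total_steps):
        lr = cosine_lr(step, total_steps, lr_max)
        # Assumed correction: shrink the weight decay along with the schedule.
        wd = weight_decay * lr / lr_max if corrected else weight_decay
        g = toy_gradient(x, rng)
        ratios.append(norm(g) / norm(x))
        x = [(1.0 - lr * wd) * a - lr * b for a, b in zip(x, g)]
    return ratios


for corrected in (False, True):
    r = run(corrected)
    print(f"corrected={corrected}: start={r[0]:.2f} "
          f"middle={r[len(r) // 2]:.2f} end={r[-1]:.2f}")
```

With the cosine schedule, the uncorrected ratio climbs well above its starting value of about 1 as the learning rate decays, whereas the corrected run stays near 1 throughout, matching the equilibrium $\sqrt{2\lambda/\gamma_{\max}} = 1$ for these settings.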