Early Weight Averaging meets High Learning Rates for LLM Pre-training
Sunny Sanyal, Atula Neerkaje, Jean Kaddour, Abhishek Kumar, Sujay Sanghavi
TL;DR
The paper tackles the cost of pre-training LLMs by applying Latest Weight Averaging (LAWA), a method that performs weight averaging early in training with high learning rates to accelerate convergence and improve generalization. LAWA samples distant checkpoints along the training trajectory and averages them. This effectively acts as a surrogate for LR decay while also harnessing ensemble-like diversity. Through extensive experiments on nanoGPT-2 (125M–770M) and Pythia (1B–12B), LAWA consistently outperforms traditional training, EMA, and SWA. It reduces the number of training steps needed while achieving better validation perplexity and zero-shot performance. Preliminary diffusion-model results suggest that LAWA’s benefits extend beyond language models, indicating broad applicability for accelerating and stabilizing training under large-batch, high-LR regimes.
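The averaging step described above can be illustrated with a minimal sketch. It assumes a checkpoint is a plain dict mapping parameter names to lists of floats, used here as a stand-in for a real model state dict. The names `LawaBuffer`, `k`, and `interval` are hypothetical and not taken from the paper's code. The buffer snapshots parameters every `interval` steps, keeps only the `k` most recent snapshots, and returns their uniform average. In this sketch the average is computed on the side and does not overwrite the parameters being trained.

```python
from collections import deque
from typing import Deque, Dict, List

# Hypothetical checkpoint format: parameter name -> flat list of floats.
Checkpoint = Dict[str, List[float]]


class LawaBuffer:
    """Keeps the k latest checkpoints, sampled every `interval` steps (illustrative)."""

    def __init__(self, k: int, interval: int) -> None:
        if k < 1 or interval < 1:
            raise ValueError("k and interval must be positive")
        self.interval = interval
        # A deque with maxlen silently drops the oldest checkpoint when full.
        self.buffer: Deque[Checkpoint] = deque(maxlen=k)

    def maybe_store(self, step: int, params: Checkpoint) -> None:
        # Snapshot only every `interval` steps so stored checkpoints are spaced apart.
        if step % self.interval == 0:
            # Copy the lists so later in-place updates cannot alter stored snapshots.
            self.buffer.append({name: list(v) for name, v in params.items()})

    def average(self) -> Checkpoint:
        # Uniform, element-wise average over the stored checkpoints.
        if not self.buffer:
            raise RuntimeError("no checkpoints stored yet")
        n = len(self.buffer)
        return {
            name: [sum(vals) / n for vals in zip(*(c[name] for c in self.buffer))]
            for name in self.buffer[0]
        }


if __name__ == "__main__":
    lawa = LawaBuffer(k=3, interval=100)
    params = {"w": [0.0, 0.0]}
    for step in range(1, 601):
        # Toy "training": the parameters simply drift with the step count.
        params["w"] = [float(step), -float(step)]
        lawa.maybe_store(step, params)
    # Only the snapshots from steps 400, 500, and 600 remain in the buffer.
    print(lawa.average())  # {'w': [500.0, -500.0]}
```

Larger `interval` values spread the averaged checkpoints further apart along the trajectory, which is the knob the abstract refers to as checkpoint spacing.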
Abstract
Training Large Language Models (LLMs) incurs significant cost; hence, any strategy that accelerates model convergence is helpful. In this paper, we investigate the ability of a simple idea, checkpoint averaging along the trajectory of a training run, to improve both convergence and generalization quite early in training. We show that models trained with high learning rates see larger gains from checkpoint averaging. Furthermore, these gains are amplified when checkpoints are sampled with considerable spacing in training steps. Our training recipe outperforms conventional training and popular checkpoint averaging baselines such as exponential moving average (EMA) and stochastic weight averaging (SWA). We evaluate our training recipe by pre-training LLMs, where high learning rates are inherently preferred due to extremely large batch sizes. Specifically, we pre-trained nanoGPT-2 models of varying sizes, small (125M), medium (335M), and large (770M), on the OpenWebText dataset, which comprises 9B tokens. Additionally, we present results for publicly available Pythia LLMs, ranging from 1B to 12B parameters, which were trained on the PILE-deduped dataset containing 207B tokens.
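For contrast with LAWA's uniform average over a fixed window of the latest checkpoints, the EMA and SWA baselines can be sketched as running averages over a flat parameter vector. This is an illustrative sketch: the class names and the list representation are assumptions. SWA is shown only in its basic form as an equal-weight running mean of collected checkpoints, omitting its learning-rate schedule and any batch-norm statistics recomputation. EMA weights recent iterates geometrically more, with `decay` setting how far back it effectively looks.

```python
from typing import List, Optional


class EmaAverager:
    """Exponential moving average of parameters (illustrative)."""

    def __init__(self, decay: float) -> None:
        if not 0.0 <= decay < 1.0:
            raise ValueError("decay must be in [0, 1)")
        self.decay = decay
        self.value: Optional[List[float]] = None

    def update(self, params: List[float]) -> None:
        if self.value is None:
            # Initialise the average with the first observed parameters.
            self.value = list(params)
        else:
            # ema <- decay * ema + (1 - decay) * params
            self.value = [
                self.decay * a + (1.0 - self.decay) * p
                for a, p in zip(self.value, params)
            ]


class SwaAverager:
    """Equal-weight running mean of every checkpoint seen so far (basic SWA)."""

    def __init__(self) -> None:
        self.count = 0
        self.value: Optional[List[float]] = None

    def update(self, params: List[float]) -> None:
        self.count += 1
        if self.value is None:
            self.value = list(params)
        else:
            # Incremental mean: avg_n = avg_{n-1} + (x_n - avg_{n-1}) / n
            self.value = [a + (p - a) / self.count for a, p in zip(self.value, params)]


if __name__ == "__main__":
    ema, swa = EmaAverager(decay=0.5), SwaAverager()
    for w in ([0.0], [2.0], [4.0]):
        ema.update(w)
        swa.update(w)
    print(ema.value, swa.value)  # [2.5] [2.0]
```

In this toy run, SWA treats all three checkpoints equally, while EMA leans toward the most recent one. LAWA sits between the two: it is uniform like SWA but forgets everything outside its window of `k` latest spaced checkpoints.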
