
Early Weight Averaging meets High Learning Rates for LLM Pre-training

Sunny Sanyal, Atula Neerkaje, Jean Kaddour, Abhishek Kumar, Sujay Sanghavi

TL;DR

The paper tackles the cost of pre-training LLMs by introducing Latest Weight Averaging (LAWA), a method that applies weight averaging early in training, under high learning rates, to accelerate convergence and improve generalization. LAWA averages checkpoints sampled far apart along the training trajectory, which acts as a surrogate for LR decay while also exploiting ensemble-like diversity. In extensive experiments on nanoGPT-2 (125M–770M) and Pythia (1B–12B), LAWA consistently outperforms conventional training, EMA, and SWA, needing fewer training steps while achieving better validation perplexity and zero-shot performance. Preliminary results on diffusion models suggest that LAWA's benefits extend beyond language models, indicating broad applicability for accelerating and stabilizing training in large-batch, high-LR regimes.
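The core mechanism can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: it assumes weights are plain dictionaries of float lists, that a checkpoint enters the window every nu steps, and that only the k most recent stored checkpoints are averaged uniformly. The class and method names are hypothetical. The averaged weights are assumed to be used for evaluation while training continues from the raw, un-averaged weights.

```python
from collections import deque
from typing import Deque, Dict, List

# Hypothetical representation: parameter name -> flattened parameter values.
Weights = Dict[str, List[float]]


class LatestWeightAverager:
    """Sketch of latest-k checkpoint averaging: keep the k most recent
    checkpoints (one every `nu` steps) and return their uniform mean."""

    def __init__(self, k: int, nu: int) -> None:
        if k < 1 or nu < 1:
            raise ValueError("k and nu must be positive")
        self.k = k
        self.nu = nu
        # maxlen=k makes the oldest checkpoint drop out automatically.
        self.buffer: Deque[Weights] = deque(maxlen=k)

    def maybe_store(self, step: int, weights: Weights) -> None:
        # Only checkpoints spaced `nu` steps apart enter the window.
        if step % self.nu == 0:
            # Copy, so later in-place optimizer updates cannot alter stored checkpoints.
            self.buffer.append({name: list(vals) for name, vals in weights.items()})

    def average(self) -> Weights:
        # Uniform mean over the (at most k) stored checkpoints, parameter by parameter.
        if not self.buffer:
            raise RuntimeError("no checkpoints stored yet")
        n = len(self.buffer)
        first = self.buffer[0]
        return {
            name: [sum(ckpt[name][i] for ckpt in self.buffer) / n
                   for i in range(len(first[name]))]
            for name in first
        }
```

With k=3 and nu=100, a run that stores checkpoints at steps 0, 100, ..., 500 ends up averaging only steps 300, 400 and 500. In a real training loop one would store copies of the model's parameter tensors rather than Python lists.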

Abstract

Training Large Language Models (LLMs) incurs significant cost; hence, any strategy that accelerates model convergence is helpful. In this paper, we investigate the ability of a simple idea, averaging checkpoints along the trajectory of a training run, to improve both convergence and generalization early in training. We show that models trained with high learning rates see larger gains from checkpoint averaging. Furthermore, these gains are amplified when the averaged checkpoints are spaced far apart in training steps. Our training recipe outperforms conventional training and popular checkpoint averaging baselines such as exponential moving average (EMA) and stochastic weight averaging (SWA). We evaluate our training recipe by pre-training LLMs, where high learning rates are inherently preferred due to extremely large batch sizes. Specifically, we pre-trained nanoGPT-2 models of varying sizes, small (125M), medium (355M), and large (770M), on the OpenWebText dataset, which comprises 9B tokens. Additionally, we present results for publicly available Pythia LLMs, ranging from 1B to 12B parameters, which were trained on the PILE-deduped dataset containing 207B tokens.
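For contrast, the two baselines named above can be written as single-step update rules. The sketch uses their textbook forms; the decay factor beta and the point at which SWA starts collecting checkpoints are assumptions, since the hyperparameters used in the paper are not given here. EMA weights recent checkpoints geometrically, SWA keeps a uniform mean over every checkpoint collected since it started, and LAWA instead takes a uniform mean over only the latest few, widely spaced checkpoints.

```python
from typing import List


def ema_update(ema: List[float], current: List[float], beta: float = 0.999) -> List[float]:
    """Exponential moving average: older checkpoints decay geometrically by `beta`."""
    return [beta * e + (1.0 - beta) * c for e, c in zip(ema, current)]


def swa_update(swa: List[float], current: List[float], n_averaged: int) -> List[float]:
    """Stochastic weight averaging: running *uniform* mean of every checkpoint
    collected so far; `n_averaged` is how many are already folded into `swa`."""
    return [(s * n_averaged + c) / (n_averaged + 1) for s, c in zip(swa, current)]
```

On a toy trajectory of 1.0, 2.0, 3.0, 4.0 (starting both averages from the first value), EMA with beta = 0.5 ends at 3.125 and SWA at 2.5, while a uniform average of just the latest two checkpoints gives 3.5. The averaging window, not only the weighting, decides how far the averaged model lags behind the current weights.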

Paper Structure (32 sections, 14 figures, 4 tables, 1 algorithm)

Figures (14)

  • Figure 1: Across all model sizes, LAWA converges faster and generalizes better than the original pre-training run and the other averaging baselines. Validation loss on OpenWebText after 70K training steps: (a) GPT2-small (125M): Original 2.963, EMA 2.949, SWA 2.952, LAWA (ours, best) 2.917; (b) GPT2-medium (355M): Original 2.855, EMA 2.845, SWA 2.837, LAWA (ours, best) 2.819; (c) GPT2-large (770M): Original 2.977, EMA 2.968, SWA 2.961, LAWA (ours, best) 2.908.
  • Figure 2: We compare two independently trained nanoGPT-2 (125M) models with LRs of $6 \times 10^{-3}$ and $6 \times 10^{-4}$ on OpenWebText. (a) Pre-training curves with and without LAWA: the LLM trained with the higher LR gains more from LAWA. (b) The model trained with the high LR generalizes worse than its counterpart trained with the low (tuned) LR. (c) LAWA effectively closes the generalization gap caused by the high LR.
  • Figure 3: LAWA illustration: given weights $\mathsf{W}_1, \mathsf{W}_2, \ldots, \mathsf{W}_k$ sampled from a high-LR trajectory at intervals of $\nu$ steps (Algorithm 1), LAWA computes their average $\mathsf{W}_{avg}$ at a given step.
  • Figure 4: LAWA saves a significant number of GPU hours compared to the original training run. We plot the GPU-hour savings as a function of the increase relative to the final perplexity, i.e., the perplexity reached by the original checkpoint at training step 141K. The plot uses a held-out set from the PhilPapers subset of the PILE training data.
  • Figure 5: LAWA speeds up convergence for Pythia-1B on a subset of tasks from its original pre-training dataset, the PILE. We show the original and LAWA training trajectories for three PILE tasks: PhilPapers, BookCorpus2 and YouTube Subtitles.
  • ...and 9 more figures
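The Figure 1 numbers are validation losses; converting them to perplexity makes the gaps easier to read. The sketch below assumes the reported loss is the mean per-token cross-entropy in nats (the usual convention for nanoGPT-style training), so that perplexity is exp(loss). The dictionary simply copies the caption's values, and the function names are illustrative.

```python
import math
from typing import Dict, Tuple

# Validation losses copied from the Figure 1 caption (OpenWebText, 70K steps).
FIG1_LOSSES: Dict[str, Dict[str, float]] = {
    "GPT2-small (125M)": {"Original": 2.963, "EMA": 2.949, "SWA": 2.952, "LAWA": 2.917},
    "GPT2-medium (355M)": {"Original": 2.855, "EMA": 2.845, "SWA": 2.837, "LAWA": 2.819},
    "GPT2-large (770M)": {"Original": 2.977, "EMA": 2.968, "SWA": 2.961, "LAWA": 2.908},
}


def perplexity(loss: float) -> float:
    """Perplexity of a mean per-token cross-entropy measured in nats (assumed)."""
    return math.exp(loss)


def lawa_gains(losses: Dict[str, float]) -> Dict[str, Tuple[float, float]]:
    """For each non-LAWA run: (loss reduction, relative perplexity reduction) of LAWA."""
    gains = {}
    for name, value in losses.items():
        if name == "LAWA":
            continue
        delta = value - losses["LAWA"]
        rel_ppl = 1.0 - perplexity(losses["LAWA"]) / perplexity(value)
        gains[name] = (delta, rel_ppl)
    return gains


for model, losses in FIG1_LOSSES.items():
    for baseline, (dloss, rel) in lawa_gains(losses).items():
        print(f"{model}: LAWA vs {baseline}: loss -{dloss:.3f}, perplexity -{100 * rel:.1f}%")
```

Because perplexity is exponential in the loss, a loss gap of d corresponds to a relative perplexity reduction of 1 - exp(-d) regardless of the absolute loss level; for example, the 0.046 gap between Original and LAWA on GPT2-small is roughly a 4.5% perplexity reduction.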