Simple and Scalable Strategies to Continually Pre-train Large Language Models

Adam Ibrahim, Benjamin Thérien, Kshitij Gupta, Mats L. Richter, Quentin Anthony, Timothée Lesort, Eugene Belilovsky, Irina Rish

TL;DR

This work tackles the compute cost of updating large language models by proposing a simple continual pre-training recipe that combines learning-rate re-warming, re-decaying, and replay of previous data. Across weak and strong distribution shifts and at 405M and 10B parameter scales, this approach matches or closely approaches the performance of full re-training on all data, while using substantially less compute. The authors also introduce infinite learning rate schedules to mitigate forgetting and enable smoother transitions across datasets, showing potential for even more scalable continual updates. Overall, the results demonstrate a practical path to keep LLMs up-to-date with new data without the prohibitive cost of re-training from scratch.

Abstract

Large language models (LLMs) are routinely pre-trained on billions of tokens, only to start the process over again once new data becomes available. A much more efficient solution is to continually pre-train these models, saving significant compute compared to re-training. However, the distribution shift induced by new data typically results in degraded performance on previous data or poor adaptation to the new data. In this work, we show that a simple and scalable combination of learning rate (LR) re-warming, LR re-decaying, and replay of previous data is sufficient to match the performance of fully re-training from scratch on all available data, as measured by the final loss and the average score on several language model (LM) evaluation benchmarks. Specifically, we show this for a weak but realistic distribution shift between two commonly used LLM pre-training datasets (English$\rightarrow$English) and a stronger distribution shift (English$\rightarrow$German) at the $405$M parameter model scale with large dataset sizes (hundreds of billions of tokens). Selecting the weak but realistic shift for larger-scale experiments, we also find that our continual learning strategies match the re-training baseline for a 10B parameter LLM. Our results demonstrate that LLMs can be successfully updated via simple and scalable continual learning strategies, matching the re-training baseline using only a fraction of the compute. Finally, inspired by previous work, we propose alternatives to the cosine learning rate schedule that help circumvent forgetting induced by LR re-warming and that are not bound to a fixed token budget.
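
As a rough illustration of what LR re-warming and re-decaying might look like in practice, the sketch below restarts a linear-warmup plus cosine-decay schedule at the start of each new dataset. The defaults ($\eta_\textit{max} = 3\cdot10^{-4}$, $\eta_\textit{min} = 0.1\cdot\eta_\textit{max}$, 1% warmup) are taken from the figure captions listed further down; the function names, the warmup starting from zero, and the per-stage step counts are assumptions made for this example, not the authors' implementation.

```python
import math

def warmup_cosine_lr(step, total_steps, eta_max=3e-4, eta_min=3e-5, warmup_frac=0.01):
    """Linear warmup to eta_max, then cosine decay to eta_min (hypothetical helper)."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        # Assumption: the warmup ramps linearly from 0 up to eta_max.
        return eta_max * step / warmup_steps
    # Fraction of the decay phase completed, clipped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return eta_min + 0.5 * (eta_max - eta_min) * (1.0 + math.cos(math.pi * progress))

def continual_lr(global_step, stage_steps, **schedule_kwargs):
    """Re-warm and re-decay: restart the schedule at each dataset boundary."""
    stage_start = 0
    for n_steps in stage_steps:
        if global_step < stage_start + n_steps:
            return warmup_cosine_lr(global_step - stage_start, n_steps, **schedule_kwargs)
        stage_start += n_steps
    # Past the last stage: stay at the minimum learning rate.
    return schedule_kwargs.get("eta_min", 3e-5)
```

Calling `continual_lr(step, [steps_on_first_dataset, steps_on_second_dataset])` would then give a schedule that decays over the first dataset and is re-warmed and re-decayed over the second. Because each cosine cycle is tied to a fixed number of steps, the schedule has to be restarted for every new dataset, which is the limitation the proposed alternative schedules aim to remove.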

Paper Structure

This paper contains 50 sections, 4 equations, 13 figures, 14 tables.

Figures (13)

  • Figure 1: Continual pre-training decreases computational costs of updating the model while maintaining similar final validation and average evaluation performance. We report results for the Pile $\cup$ SlimPajama (SP)/German (Ger.) baseline model, trained on the union of both datasets, which we consider to be an upper bound on performance. We also report performance for two continually pre-trained models. "PT on Pile" starts from a pre-trained Pile checkpoint and only uses learning rate re-warming and re-decaying, while "Replay (PT on Pile)" re-warms the learning rate, re-decays it, and uses 5% replay for SlimPajama and 25% replay for German. We observe that the combination of LR re-warming, re-decaying, and replay allows our continually pre-trained model to attain similar average performance to the baseline model while requiring substantially less compute. We note that this setting assumes that a pre-trained model is available (e.g., via HuggingFace hub or an in-house model designed to be continually pre-trained).
  • Figure 2: Linear warmup and cosine annealing schedule. For illustration purposes, the schedule uses linear warmup for 10% of training iterations. However, most works use a warmup duration between 0.1% and 0.5% of training steps (Zhao et al., 2023).
  • Figure 3: The effect of linear warmup for weak and strong distribution shifts. (a), (b) and (c), (d) share the same legends respectively, shown in the right figures. We train 405M-parameter models following a linear warmup and cosine decay schedule with varying linear warmup durations: 0%, 0.5%, 1%, and 2% of training iterations. Each learning rate schedule decays to $0.1\eta_\textit{max}$ by the end of training based on the size of the dataset. We report results for the first 50B tokens of training. In the settings explored, we observe that the duration of the warm-up phase does not appear to be impactful when continuing to pre-train.
  • Figure 4: The effect of re-warming and re-decaying the learning rate on adaptation and forgetting. We consider two constant baselines and three models that re-warm and re-decay. One baseline continues training from $\eta_\textit{min}$ of pre-training ($3\cdot10^{-5}$) while the other warms up to $\eta_\textit{max}$ from pre-training ($3\cdot10^{-4}$). For the models that re-warm and re-decay, we vary $\eta_\textit{max} \in \{1.5\cdot10^{-4},3\cdot10^{-4},6\cdot10^{-4}\}$. All models except the $\eta_\textit{min}$ baseline use linear warmup for $1\%$ of training iterations. The non-baseline models cosine decay the learning rate to reach $0.1\cdot\eta_\textit{max}$ by the end of training. We observe that re-warming and re-decaying the learning rate is needed to best adapt to the new dataset. Small increases or decreases in $\eta_\textit{max}$ make it possible to trade off between more or less adaptation. A stronger distribution shift seems to be a catalyst for both forgetting and adaptation.
  • Figure 5: The effect of replay at 405M scale for weak and strong distribution shifts. We report Pile validation loss (left) and SlimPajama/German validation loss (right top/bottom) during training. Each model is trained from a checkpoint pre-trained on $300$B tokens of Pile. The blue dotted line reports the final validation loss for models trained on Pile$\cup$SlimPajama or Pile$\cup$German data, totaling 600B-token and 500B-token datasets respectively. We observe that replay significantly reduces forgetting across both shifts; however, the stronger shift requires more replay to mitigate forgetting to the same extent.
  • ...and 8 more figures
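
The replay percentages quoted in the captions above (5% for SlimPajama, 25% for German) can be read as the share of training samples drawn from the previous dataset while training on the new one. A minimal sketch of such a mixture is shown below, assuming each sample is drawn independently from the old data with probability `replay_frac`; the function name `replay_mixture` and the per-sample Bernoulli draw are assumptions for illustration and may differ from how the authors build their batches.

```python
import random
from itertools import islice, repeat

def replay_mixture(new_data, old_data, replay_frac=0.05, seed=0):
    """Yield samples from new_data, replacing a fraction with old_data samples.

    Hypothetical helper: each yielded sample comes from old_data with
    probability replay_frac, so the total number of samples seen stays
    the same as training on new_data alone.
    """
    rng = random.Random(seed)
    new_it, old_it = iter(new_data), iter(old_data)
    while True:
        source = old_it if rng.random() < replay_frac else new_it
        try:
            yield next(source)
        except StopIteration:
            # Stop as soon as either stream is exhausted.
            return

# Usage: 5% replay of "old" samples while training on "new" samples.
stream = replay_mixture(repeat("new"), repeat("old"), replay_frac=0.05)
first_batch = list(islice(stream, 8))
```

Under this reading, a stronger shift would simply use a larger `replay_frac` (for example 0.25), trading some exposure to the new data for less forgetting of the old data.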