Efficient Continual Pre-training by Mitigating the Stability Gap

Yiduo Guo, Jie Fu, Huishuai Zhang, Dongyan Zhao, Yikang Shen

TL;DR

The paper investigates the stability gap that appears when LLMs are continually pre-trained on a new domain: domain-task performance drops at first and then recovers. It interprets the phenomenon in terms of plasticity and stability gradients and proposes three practical strategies (subset-based multi-epoch training, high-quality data selection, and a data mixture close to the original pre-training data) to mitigate the gap under a fixed compute budget. Experiments on OpenLlama-3B and Llama-3-8B show improved medical-domain performance, and the resulting Llama-3-Physician model performs comparably to GPT-4 on several medical benchmarks. The work offers actionable guidance for efficient domain adaptation of open-source LLMs and releases an open-source medical LLM as a baseline for the community.

Abstract

Continual pre-training has increasingly become the predominant approach for adapting Large Language Models (LLMs) to new domains. This process involves updating the pre-trained LLM with a corpus from a new domain, resulting in a shift in the training distribution. To study the behavior of LLMs during this shift, we measured the model's performance throughout the continual pre-training process. We observed a temporary performance drop at the beginning, followed by a recovery phase, a phenomenon known as the "stability gap," previously noted in vision models classifying new classes. To address this issue and enhance LLM performance within a fixed compute budget, we propose three effective strategies: (1) continually pre-training the LLM on a properly sized subset for multiple epochs, which yields faster performance recovery than pre-training the LLM on a large corpus for a single epoch; (2) pre-training the LLM only on a high-quality sub-corpus, which rapidly boosts domain performance; and (3) using a data mixture similar to the pre-training data to reduce the distribution gap. We conduct various experiments on Llama-family models to validate the effectiveness of our strategies in both medical continual pre-training and instruction tuning. For example, our strategies improve the average medical task performance of the OpenLlama-3B model from 36.2% to 40.7% with only 40% of the original training budget and enhance the average general task performance without causing forgetting. Furthermore, we apply our strategies to the Llama-3-8B model. The resulting model, Llama-3-Physician, achieves the best medical performance among current open-source models and performs comparably to or even better than GPT-4 on several medical benchmarks. We release our models at https://huggingface.co/YiDuo1999/Llama-3-Physician-8B-Instruct.
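
Strategies (1) and (3) come down to two mechanical steps: deciding how many passes over a subset fit in a fixed token budget, and drawing training data from several sources in fixed proportions. The following is a minimal sketch under stated assumptions, not the authors' implementation. The function names `epochs_within_budget` and `sample_mixture`, the source names, the mixture weights, and the token counts are all hypothetical and chosen only for illustration.

```python
import random
from collections import Counter


def epochs_within_budget(budget_tokens: int, subset_tokens: int) -> int:
    """Number of full passes over a subset that fit in a fixed token budget."""
    if subset_tokens <= 0:
        raise ValueError("subset_tokens must be positive")
    return budget_tokens // subset_tokens


def sample_mixture(weights: dict[str, float], n: int, seed: int = 0) -> list[str]:
    """Draw n source names with probability proportional to the given weights."""
    rng = random.Random(seed)
    names = list(weights)
    return rng.choices(names, weights=[weights[k] for k in names], k=n)


# Illustrative numbers only: a 20B-token budget spent on a 5B-token subset.
print(epochs_within_budget(20_000_000_000, 5_000_000_000))  # 4

# Hypothetical mixture: a high-quality domain subset blended with
# general-purpose sources that resemble the original pre-training data.
mixture = {"medical_hq": 0.5, "web": 0.3, "code": 0.1, "books": 0.1}
draws = sample_mixture(mixture, n=10_000)
print(Counter(draws).most_common())
```

In a real pipeline each drawn source name would select which dataset the next training sequence is read from; the weights themselves are a tuning choice that this sketch does not try to determine.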

Paper Structure (43 sections, 8 figures, 5 tables)

Figures (8)

  • Figure 1: Performance comparison between our model (Llama-3-Physician) and other baselines. For each task, we report the ratio of each model's performance to the best performance on that task among all models.
  • Figure 2: (a) reports the models' average medical performance during the medical continual pre-training process. (b) illustrates the models' average medical perplexity (PPL) during the medical continual pre-training process. (c) shows the Pythia model's average common-sense task performance when we continually pre-train it on the new Refined-Web datasets.
  • Figure 3: (a) shows the OpenLlama model's average common-sense task performance during medical continual pre-training. (b) illustrates the OpenLlama model's relative parameter updates during the medical continual pre-training process. We report the average relative weight update in the top 5 layers and in the bottom 5 layers, as well as the ratio between the two averages.
  • Figure 4: (a) reports the average medical performance during the medical continual pre-training process. The baseline pre-trains the OpenLlama-3B model on 50B medical tokens for one epoch. '5b Random' pre-trains the LLM for 5 epochs on 5B tokens randomly selected from the 50B medical tokens. '5b HQ' pre-trains the LLM for 5 epochs on the 5B highest-quality (HQ) tokens of the 50B medical tokens. (b) shows the average medical performance across 5 epochs. (c) illustrates the average common-sense task performance across 5 epochs.
  • Figure 5: (a) reports the performance of TinyLlama-1.1B across multiple epochs. (b) reports the performance of OpenLlama-3B across multiple epochs. All experiments in (a) and (b) use our strategies with different pre-training learning rates. (c) reports the performance of OpenLlama-3B across multiple epochs with different training subset sizes $S$. To collect pre-training corpora of different sizes, we first rank all samples of the 50 billion medical tokens by the perplexity computed with the trained KenLM (see Sec. \ref{sec.identify}). Then, we select the first $S$ billion tokens with the lowest perplexity. For all experiments here, we report the average task performance on the PubMedQA, MedMCQA, MMLU-medical-genetics, and MedQA-4-Option tasks.
  • ...and 3 more figures
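
Three of the quantities named in the captions above are simple to compute: the ratio-to-best score in Figure 1, the relative parameter update in Figure 3(b), and the lowest-perplexity subset selection in Figure 5(c). The sketch below shows one way to compute each, under stated assumptions. All function names are hypothetical and the data is made up. The L2 norm in `relative_update` is an assumption, since the captions do not say which norm is used. Perplexities are taken as precomputed inputs rather than scored with KenLM.

```python
import math


def ratio_to_best(scores):
    """Divide each model's score on a task by the best score on that task."""
    tasks = {t for per_model in scores.values() for t in per_model}
    best = {t: max(m[t] for m in scores.values() if t in m) for t in tasks}
    return {
        model: {t: s / best[t] for t, s in per_model.items()}
        for model, per_model in scores.items()
    }


def relative_update(before, after):
    """L2 norm of the weight change divided by the L2 norm of the old weights.

    The choice of the L2 norm is an assumption made for this sketch.
    """
    diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(after, before)))
    base = math.sqrt(sum(b ** 2 for b in before))
    return diff / base


def select_lowest_perplexity(samples, budget_tokens):
    """Keep the lowest-perplexity samples until the token budget is filled.

    Each sample is a (sample_id, perplexity, n_tokens) tuple.
    """
    chosen, used = [], 0
    for sample_id, _ppl, n_tokens in sorted(samples, key=lambda s: s[1]):
        if used + n_tokens > budget_tokens:
            break
        chosen.append(sample_id)
        used += n_tokens
    return chosen


# Made-up numbers for illustration only.
scores = {
    "model_a": {"task_x": 60.0, "task_y": 40.0},
    "model_b": {"task_x": 75.0, "task_y": 50.0},
}
print(ratio_to_best(scores)["model_a"])  # {'task_x': 0.8, 'task_y': 0.8}

print(relative_update([3.0, 4.0], [3.0, 4.5]))  # 0.1

samples = [("d1", 120.0, 400), ("d2", 35.0, 300), ("d3", 80.0, 500), ("d4", 50.0, 300)]
print(select_lowest_perplexity(samples, budget_tokens=1000))  # ['d2', 'd4']
```

The selection stops at the first sample that would exceed the budget instead of skipping ahead to smaller samples, which keeps the chosen set a strict prefix of the perplexity ranking.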