Replaying pre-training data improves fine-tuning

Suhas Kotha, Percy Liang

TL;DR

Replaying generic pre-training data during fine-tuning can, surprisingly, improve performance on the (less related) target task.

Abstract

To obtain a language model for a target domain (e.g. math), the current paradigm is to pre-train on a vast amount of generic web text and then fine-tune on the relatively limited amount of target data. Typically, generic data is only mixed in during fine-tuning to prevent catastrophic forgetting of the generic domain. We surprisingly find that replaying the generic data during fine-tuning can actually improve performance on the (less related) target task. Concretely, in a controlled pre-training environment with 4M target tokens, 4B total tokens, and 150M parameter models, generic replay increases target data efficiency by up to $1.87\times$ for fine-tuning and $2.06\times$ for mid-training. We further analyze data schedules that introduce target data during pre-training and find that replay helps more when there is less target data present in pre-training. We demonstrate the success of replay in practice for fine-tuning 8B parameter models, improving agentic web navigation success by $4.5\%$ and Basque question-answering accuracy by $2\%$.

Paper Structure

This paper contains 59 sections, 24 figures, and 1 table.

Figures (24)

  • Figure 1: Replaying the generic distribution can improve target performance. Standard fine-tuning trains on all target data (blue) after all generic data (purple). We find that replaying generic data during fine-tuning can surprisingly improve performance on the target domain, both for fine-tuning and mid-training (e.g. $1.87\times$ and $2.06\times$ for FineMath, respectively). We find that replay is most helpful when there is less target data present during pre-training.
  • Figure 2: Data scaling law for reference algorithm. We run a reference training strategy with different target data budgets. To estimate how effectively an algorithm uses its data, we invert the reference strategy's scaling law to recover the "effective data" corresponding to the loss that algorithm achieves, and then compare the data efficiency improvement between two strategies. All of our data efficiency estimates only need to interpolate this scaling law.
  • Figure 3: Controlled fine-tuning visualization. We systematically explore the benefit of replaying generic data while fine-tuning on the target data. On the right, we show standard fine-tuning for $T$ steps where $\gamma$ fraction of the steps are on target data. On the left, we show fine-tuning with replay fraction $\rho$ (where we shorten pre-training to keep the total number of steps fixed). We use (independently tuned) cosine learning rate schedules for each stage, with an optimizer state reset between the stages to simulate standard practice for fine-tuning open-weight models.
  • Figure 4: Replay improves loss on target data. We show that across our target domains, a well-chosen amount of replay (starred points) beats the no-replay baseline (dotted line). Though data distributions closer to pre-training (Flan) can tolerate more replay than more distant domains (StarCoder), the loss improvement is relatively constant across domains.
  • Figure 5: Tuning learning rate cooldown. We tune how long to cool down the learning rate for WSD. The plot shows that the optimal cooldown period is between 0.05 and 0.1; we use 0.1 for consistency across domains and to be fair to changing data schedules.
  • ...and 19 more figures
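
The "effective data" idea in the Figure 2 caption can be illustrated with a small sketch. This is not the authors' code: the function names, the example numbers, and the choice to interpolate linearly between loss and log-tokens are all assumptions made for illustration. The caption only states that the reference strategy's scaling law is inverted and interpolated.

```python
import numpy as np


def effective_data(loss, ref_tokens, ref_losses):
    """Estimate how many target tokens the reference strategy would need
    to reach `loss`, by interpolating its (tokens, loss) measurements.

    Assumption: piecewise-linear interpolation of log(tokens) against loss.
    """
    ref_tokens = np.asarray(ref_tokens, dtype=float)
    ref_losses = np.asarray(ref_losses, dtype=float)

    # np.interp requires increasing x values, so sort the points by loss.
    order = np.argsort(ref_losses)
    x = ref_losses[order]
    y = np.log(ref_tokens[order])

    # Only interpolate; refuse to extrapolate beyond the measured range.
    if not (x[0] <= loss <= x[-1]):
        raise ValueError("loss lies outside the reference curve")

    return float(np.exp(np.interp(loss, x, y)))


def data_efficiency(loss_new, loss_base, ref_tokens, ref_losses):
    """Ratio of effective data between two strategies (new vs. baseline)."""
    return (effective_data(loss_new, ref_tokens, ref_losses)
            / effective_data(loss_base, ref_tokens, ref_losses))


# Hypothetical reference curve: loss measured at four target-token budgets.
ref_tokens = [1e6, 2e6, 4e6, 8e6]
ref_losses = [3.0, 2.8, 2.6, 2.4]

# A strategy reaching loss 2.5 where the baseline reaches 2.6.
gain = data_efficiency(2.5, 2.6, ref_tokens, ref_losses)
print(f"data efficiency improvement: {gain:.2f}x")
```

With these made-up numbers, a strategy that lowers the loss from 2.6 to 2.5 behaves as if it had roughly 1.41 times as much target data. Refusing to extrapolate mirrors the caption's remark that the estimates only need interpolation.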
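
The two-stage schedule in the Figure 3 caption can also be sketched in a few lines. This is a rough illustration under stated assumptions, not the paper's implementation. It assumes that $\rho$ is the fraction of fine-tuning-stage steps spent on generic data, that the number of target steps stays at $\gamma T$, and that each stage uses a plain cosine decay without warmup. The names `replay_schedule`, `cosine_lr`, and `Schedule` are hypothetical.

```python
import math
from dataclasses import dataclass


@dataclass
class Schedule:
    pretrain_steps: int          # generic-only steps in stage 1
    finetune_steps: int          # total steps in stage 2
    finetune_target_steps: int   # stage-2 steps on target data
    finetune_replay_steps: int   # stage-2 steps on replayed generic data


def replay_schedule(total_steps, gamma, rho):
    """Split a fixed step budget into pre-training and fine-tuning with replay.

    Assumption: rho is the generic fraction of the fine-tuning stage, so the
    stage grows to gamma * T / (1 - rho) steps and pre-training shrinks to
    keep the total number of steps fixed.
    """
    if not (0.0 < gamma < 1.0) or not (0.0 <= rho < 1.0):
        raise ValueError("need 0 < gamma < 1 and 0 <= rho < 1")
    target_steps = round(gamma * total_steps)
    finetune_steps = round(target_steps / (1.0 - rho))
    pretrain_steps = total_steps - finetune_steps
    if pretrain_steps < 0:
        raise ValueError("replay fraction leaves no room for pre-training")
    return Schedule(pretrain_steps, finetune_steps, target_steps,
                    finetune_steps - target_steps)


def cosine_lr(step, stage_steps, peak_lr):
    """Cosine decay from peak_lr to 0 within one stage (no warmup assumed).

    Each stage calls this with its own step counter and peak_lr, which
    corresponds to restarting the schedule (and, in a real trainer,
    resetting the optimizer state) between stages.
    """
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * step / stage_steps))


no_replay = replay_schedule(total_steps=1000, gamma=0.1, rho=0.0)
with_replay = replay_schedule(total_steps=1000, gamma=0.1, rho=0.5)
print(no_replay)
print(with_replay)
```

Under these assumptions, setting `rho=0.5` doubles the fine-tuning stage from 100 to 200 steps and shortens pre-training from 900 to 800 steps, so both runs consume the same total budget.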