Scaling Laws for Forgetting during Finetuning with Pretraining Data Injection

Louis Bethune, David Grangier, Dan Busbridge, Eleonora Gualdoni, Marco Cuturi, Pierre Ablin

TL;DR

This study measures the efficiency of injecting pretraining data into the finetuning data mixture to avoid forgetting and mitigate overfitting, and derives scaling laws that quantify these two phenomena for various target domains, amounts of available target data, and model scales.

Abstract

A widespread strategy to obtain a language model that performs well on a target domain is to finetune a pretrained model to perform unsupervised next-token prediction on data from that target domain. Finetuning presents two challenges: (i) if the amount of target data is limited, as in most practical applications, the model will quickly overfit, and (ii) the model will drift away from the original model, forgetting the pretraining data and the generic knowledge that comes with it. We aim to derive scaling laws that quantify these two phenomena for various target domains, amounts of available target data, and model scales. We measure the efficiency of injecting pretraining data into the finetuning data mixture to avoid forgetting and mitigate overfitting. A key practical takeaway from our study is that injecting as little as 1% of pretraining data in the finetuning data mixture prevents the model from forgetting the pretraining set.
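The injection scheme described in the abstract can be pictured as a sampler that mixes the two data sources with a fraction $p$ of pretraining data. The sketch below is a minimal illustration under stated assumptions, not the paper's pipeline: each training example is assumed to be drawn independently from the pretraining set with probability $p$ and from the finetuning set otherwise, uniformly and with replacement. The function name `mixed_stream` and this exact sampling scheme are hypothetical.

```python
import random
from typing import Iterator, Sequence, Tuple, TypeVar

T = TypeVar("T")


def mixed_stream(
    finetune: Sequence[T],
    pretrain: Sequence[T],
    p: float,
    num_samples: int,
    seed: int = 0,
) -> Iterator[Tuple[str, T]]:
    """Yield (source, example) pairs; hypothetical per-example Bernoulli(p) mixing.

    With probability p the example comes from the pretraining set,
    otherwise from the finetuning set (uniform sampling with replacement).
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    rng = random.Random(seed)  # seeded for reproducible mixtures
    for _ in range(num_samples):
        if rng.random() < p:
            yield "pretrain", rng.choice(pretrain)
        else:
            yield "finetune", rng.choice(finetune)


# Usage: with p = 1%, roughly one example in a hundred comes from pretraining data.
stream = mixed_stream(finetune=["ft"] * 10, pretrain=["pt"] * 10, p=0.01, num_samples=100_000)
share = sum(src == "pretrain" for src, _ in stream) / 100_000
print(f"share of pretraining examples: {share:.4f}")  # close to 0.01
```

A real setup would more likely iterate over the finetuning set in epochs and mix at the batch or token level; per-example sampling is chosen here only because it makes the role of $p$ explicit.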

Paper Structure

This paper contains 33 sections, 14 equations, 20 figures, 4 tables.

Figures (20)

  • Figure 1: Injecting as little as $p=1\%$ of pretraining data shields the model from forgetting the pretraining dataset. Setting: GitHub dataset, small model. The finetuning validation loss follows a conventional U-curve. Throughout the paper, we consider the models obtained at the bottom of the U-curve, that is, the models with the best validation loss on the finetuning set, marked here by a black dot. The minimum validation loss is barely affected by the amount of injected pretraining data $p$, but reaching it takes more iterations as $p$ increases. The loss on the finetuning training set converges to zero as training progresses, since the network memorizes the dataset. The pretraining loss increases monotonically during finetuning. Injecting pretraining data has a regularizing effect that reduces both overfitting and forgetting.
  • Figure 2: Generalization-memorization tradeoff on the Arxiv domain. Each point corresponds to the bottom of the U-curve for a model trained on finetuning datasets of 300K, 900K, 3,000K, 9,000K and 30,000K tokens with mixture parameter $p=1\%$. Forgetting is more severe when the model is small and when the finetuning dataset is large. As the finetuning scaling law (Equation \ref{eq:finetuningscalinglaw}) indicates, this can be attributed to limited model capacity: more parameters are devoted to memorizing the finetuning training set, leaving fewer for performance on the pretraining set.
  • Figure 3: Losses as a function of the fraction of injected pretraining data $p$ on Enron emails with 900K finetuning tokens. Data mixing improves generalization when finetuning data is scarce: the diversity of the pretraining dataset biases learning toward features that generalize better. The optimal value of $p$ depends on the domain, the dataset size, and the model size. The finetuning loss as a function of $p$ also follows a U-curve. When $p$ is too small, the model overfits quickly and does not benefit from the regularizing effect of pretraining data injection. When $p$ is too large, the model does not see enough finetuning data to allocate enough capacity to it, because learning from the pretraining set competes too strongly. As expected, increasing $p$ monotonically decreases the pretraining loss.
  • Figure 4: Overfitting and forgetting profiles for two domains. DM Mathematics (left) differs strongly from the pretraining set: it benefits little from more parameters and a lot from more data. Wikipedia En (right) is closer to the pretraining set, and more parameters help more than more training data. Datasets far from the pretraining distribution are more prone to forgetting and benefit the most from injecting pretraining data ($p>0$).
  • Figure 5: Finetuning learning rate as a fraction of the peak pretraining learning rate. Ablation on Arxiv with the XL model (1.3B). During pretraining, the learning rate (LR) is decayed by a factor of 100 with a cosine schedule. For finetuning, we use a constant LR, defined as a fraction of the peak LR. We observe that setting the finetuning LR to $1/30$ of the peak value strikes the best balance between effective adaptation to the finetuning dataset and stability of the pretrained features. Notably, this factor of $1/30$ corresponds to the LR reached at 90% of the pretraining phase, near convergence.
  • ...and 15 more figures
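The Figure 5 caption relates the finetuning LR to the pretraining cosine schedule. The arithmetic can be checked with a short sketch, assuming a standard cosine decay from the peak LR to peak/100 over the whole pretraining run and ignoring any warmup phase (both are simplifying assumptions; the helper names `cosine_lr` and `fraction_reaching` are hypothetical).

```python
import math


def cosine_lr(t: float, peak: float = 1.0, decay_factor: float = 100.0) -> float:
    """LR at training fraction t in [0, 1]: cosine decay from peak to peak / decay_factor.

    Assumes no warmup phase (hypothetical simplification).
    """
    final = peak / decay_factor
    return final + 0.5 * (peak - final) * (1.0 + math.cos(math.pi * t))


def fraction_reaching(target: float, peak: float = 1.0, decay_factor: float = 100.0) -> float:
    """Inverse of cosine_lr: training fraction at which the LR equals target."""
    final = peak / decay_factor
    if not final <= target <= peak:
        raise ValueError("target must lie between the final and peak LR")
    return math.acos(2.0 * (target - final) / (peak - final) - 1.0) / math.pi


ratio = cosine_lr(0.9)  # LR at 90% of pretraining, as a fraction of the peak
print(f"LR at 90%: {ratio:.4f} x peak (about 1/{1 / ratio:.1f})")
# → LR at 90%: 0.0342 x peak (about 1/29.2)
print(f"LR = peak/30 is reached at {fraction_reaching(1 / 30):.1%} of pretraining")
# → LR = peak/30 is reached at 90.2% of pretraining
```

Under these assumptions, $1/30$ of the peak LR and the LR at 90% of pretraining agree to within a few percent, which is consistent with the caption's statement.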