
Scaling Laws for Forgetting When Fine-Tuning Large Language Models

Damjan Kalajdzievski

TL;DR

This work investigates the forgetting that occurs when pre-trained large language models are fine-tuned with parameter-efficient techniques such as LoRA. It introduces a cross-entropy-based forgetting metric and shows that forgetting is well predicted by a linear function of fine-tuning loss, and that it follows a shifted power law in the number of tunable parameters and the number of update steps. The authors fit joint laws for forgetting and fine-tuning loss, find consistent exponents across datasets, and show that forgetting also degrades generation-related capabilities such as reasoning on ARC and safety alignment on AdvBench. The study emphasizes the need for forgetting-mitigation strategies that preserve pre-trained capabilities while still enabling task adaptation.

Abstract

We study and quantify the problem of forgetting when fine-tuning pre-trained large language models (LLMs) on a downstream task. We find that parameter-efficient fine-tuning (PEFT) strategies, such as Low-Rank Adapters (LoRA), still suffer from catastrophic forgetting. In particular, we identify a strong inverse linear relationship between the fine-tuning performance and the amount of forgetting when fine-tuning LLMs with LoRA. We further obtain precise scaling laws that show forgetting increases as a shifted power law in the number of parameters fine-tuned and the number of update steps. We also examine the impact of forgetting on knowledge, reasoning, and the safety guardrails trained into Llama 2 7B chat. Our study suggests that forgetting cannot be avoided through early stopping or by varying the number of parameters fine-tuned. We believe this opens up an important safety-critical direction for future research to evaluate and develop fine-tuning schemes which mitigate forgetting.
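The inverse linear relationship described above can be checked with an ordinary least-squares line fit and a coefficient of determination. The following is a minimal sketch, not the authors' fitting code: the helper names `fit_line` and `r_squared` are hypothetical, and the loss and forgetting values are made-up illustrative numbers rather than results from the paper.

```python
# Minimal sketch: fit forgetting as a linear function of fine-tuning loss.
# All numbers below are synthetic placeholders, NOT measurements from the paper.

def fit_line(xs, ys):
    """Ordinary least-squares fit of ys ~ slope * xs + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept

def r_squared(xs, ys, slope, intercept):
    """Coefficient of determination of the fitted line."""
    mean_y = sum(ys) / len(ys)
    ss_res = sum((y - (slope * x + intercept)) ** 2 for x, y in zip(xs, ys))
    ss_tot = sum((y - mean_y) ** 2 for y in ys)
    return 1.0 - ss_res / ss_tot

# Hypothetical checkpoints: lower fine-tuning loss pairs with more forgetting.
ft_loss = [2.0, 1.8, 1.6, 1.4, 1.2]
forgetting = [0.10, 0.21, 0.29, 0.41, 0.50]

slope, intercept = fit_line(ft_loss, forgetting)
r2 = r_squared(ft_loss, forgetting, slope, intercept)
print(f"slope={slope:.3f} intercept={intercept:.3f} R^2={r2:.4f}")
```

A negative slope here corresponds to the inverse relationship: as fine-tuning loss decreases (better fine-tuning performance), forgetting increases.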

Paper Structure (15 sections, 9 equations, 8 figures, 1 table)


Figures (8)

  • Figure 1: Generation examples from the pre-trained model and from a model fine-tuned with LoRA on a dataset of recent news articles (see the experimental setup section for a description of the dataset). These generations exemplify the updated knowledge, forgotten knowledge (ARC dataset), and forgotten safety/alignment behavior (AdvBench dataset) resulting from fine-tuning.
  • Figure 2: Fine-tuning performance vs. forgetting on the OpenOrca (Left) and News (Right) datasets. The inverse linear relationship between forgetting and fine-tuning is shown in black, while evaluations for fine-tuning runs with different numbers of parameters are scatter plotted in color. We obtain a strong fit, with coefficients of determination of 0.9450 and 0.9736 for OpenOrca and News respectively. This shows that forgetting depends primarily on fine-tuning loss, and paints a pessimistic picture: if one uses conventional fine-tuning approaches to achieve a certain level of fine-tuning dataset performance, forgetting cannot be avoided by early stopping or by tuning fewer (or more) parameters.
  • Figure 3: Forgetting and fine-tuning loss trajectories and fit curves for varying ranks. (Left) OpenOrca dataset. (Right) News dataset. Our fit functions for $\mathcal{L}_{\text{f}}(P,N)$ and $\mathcal{L}_{\text{ft}}(P,N)$ are plotted with solid lines, and the dotted lines are the data trajectories. Note the consistent relationships between fine-tuning or forgetting and $P,N$ across very different types of fine-tuning data. The fit for forgetting as a function of $P$ and $N$ takes into account some of the extra spread in forgetting relative to $\mathcal{L}_{\text{f}}(\mathcal{L}_{\text{ft}})$, and thus improves the fit from an $R^2$ of 0.9450 and 0.9736 to 0.9598 and 0.9769 on OpenOrca and News respectively.
  • Figure 4: Forgetting on the ARC dataset. Shown are checkpoints of the smallest model trained, evaluated every 50 steps of training, while fine-tuned on the datasets OpenOrca (Left) and News (Right). Pre-trained base model accuracy is in red, fine-tuned model accuracy is in blue, and accuracy vs. the base model predictions is in green. We see that the performance of the model deteriorates substantially when fine-tuning on the News dataset. In contrast, the reasoning capability of the model is not as dramatically affected while training on OpenOrca. This is intuitive, since OpenOrca largely contains data explicitly exhibiting reasoning, whereas the News dataset does not. We note that on the OpenOrca models, accuracy with respect to the pre-trained model's predictions shows the forgetting, while the usual accuracy does not. This is because the fine-tuned model makes a different set of errors than the base model.
  • Figure 5: Example of the rank 8 OpenOrca model forgetting safety tuning on AdvBench after fine-tuning (Left). In this example the base pre-trained model correctly generated a refusal (Right). Bold text in square brackets is editorial. See the appendix for additional examples.
  • ...and 3 more figures
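
The Figure 4 caption distinguishes the usual accuracy (against ground-truth labels) from accuracy measured against the base model's predictions. A minimal sketch of that distinction follows; the `accuracy` helper and the answer lists are hypothetical toy values chosen for illustration, not data from the paper.

```python
# Minimal sketch: ground-truth accuracy vs. agreement with the base model.
# The answer lists are toy placeholders, NOT results from the paper.

def accuracy(predictions, targets):
    """Fraction of positions where predictions match targets."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    matches = sum(p == t for p, t in zip(predictions, targets))
    return matches / len(targets)

# Hypothetical multiple-choice answers for six questions.
labels      = ["A", "B", "C", "D", "A", "B"]  # ground truth
base_preds  = ["A", "B", "C", "A", "B", "B"]  # pre-trained model
tuned_preds = ["A", "C", "C", "D", "A", "C"]  # fine-tuned model

base_acc = accuracy(base_preds, labels)        # usual accuracy, base model
tuned_acc = accuracy(tuned_preds, labels)      # usual accuracy, fine-tuned model
agreement = accuracy(tuned_preds, base_preds)  # accuracy vs. base predictions

print(f"base={base_acc:.3f} tuned={tuned_acc:.3f} agreement={agreement:.3f}")
```

In this toy case both models answer four of six questions correctly, yet they agree on only two, which is the situation the caption describes: the fine-tuned model makes a different set of errors, so only the comparison against base-model predictions reveals the change.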