Chained Tuning Leads to Biased Forgetting

Megan Ung, Alicia Sun, Samuel J. Bell, Bhaktipriya Radharapu, Levent Sagun, Adina Williams

TL;DR

This work analyzes how chained finetuning of large language models affects the retention of safety alignment. It finds that forgetting depends on both task order and demographic group, and it terms the uneven forgetting across groups biased forgetting. The authors introduce a formal framework and a new metric to quantify forgetting across tasks and groups. They demonstrate that safety tasks are forgotten more when followed by a capability task than capability tasks are when followed by a safety task, with significant disparities across demographic groups. They show that the curvature of the minimum reached after the first task (proxied by the spectral radius of the Hessian) correlates with downstream forgetting and that wider minima mitigate forgetting, suggesting optimization-based ways to curb safety loss. Finally, they propose a mitigation that replays data in a third finetuning stage and show that even a small amount of safety data (~5%) can substantially restore safety performance at minimal cost to capability tasks, highlighting practical strategies for safer continual learning in LLMs.
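The replay mitigation can be pictured as assembling the data for the third finetuning stage. The sketch below is illustrative only: `build_third_stage_data` is a hypothetical helper, and it assumes that "~5%" means 5% of the original safety training set, replayed (sampled without replacement) and shuffled together with any other third-stage data. The paper's exact mixing scheme is not given in this summary.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def build_third_stage_data(
    safety_data: Sequence[T],
    replay_fraction: float = 0.05,
    other_data: Optional[Sequence[T]] = None,
    seed: int = 0,
) -> List[T]:
    """Hypothetical helper: replay a fraction of the safety set in stage 3.

    ASSUMPTION: replay_fraction is the share of the original safety training
    set that is replayed; other_data (e.g. capability examples) is optional
    and is shuffled together with the replayed subset.
    """
    if not 0.0 < replay_fraction <= 1.0:
        raise ValueError("replay_fraction must be in (0, 1]")
    rng = random.Random(seed)
    # Replay at least one example when the safety set is non-empty.
    n_replay = min(len(safety_data), max(1, round(replay_fraction * len(safety_data))))
    # Sample without replacement so no safety example is replayed twice.
    replayed = rng.sample(list(safety_data), n_replay)
    rest = list(other_data) if other_data is not None else []
    mix = replayed + rest
    rng.shuffle(mix)  # interleave replayed safety data with the other data
    return mix


if __name__ == "__main__":
    safety = [f"safety-{i}" for i in range(1000)]
    capability = [f"cap-{i}" for i in range(2000)]
    stage3 = build_third_stage_data(safety, 0.05, capability)
    print(len(stage3), sum(x.startswith("safety-") for x in stage3))  # prints 2050 50
```

Sampling without replacement and shuffling keeps the replayed safety examples spread across the stage-3 batches rather than clustered at the start or end of training.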

Abstract

Large language models (LLMs) are often fine-tuned for use on downstream tasks, though this can degrade capabilities learned during previous training. This phenomenon, often referred to as catastrophic forgetting, has important potential implications for the safety of deployed models. In this work, we first show that models fine-tuned on downstream tasks after safety tuning forget their safety tuning to a greater extent than models trained in the opposite order forget their downstream task. Second, we show that forgetting disproportionately impacts safety information about certain groups. To quantify this phenomenon, we define a new metric, which we term biased forgetting. We conduct a systematic evaluation of the effects of task ordering on forgetting and apply mitigations that can help the model recover from the forgetting observed. We hope our findings can better inform methods for chaining the finetuning of LLMs in continual learning settings to enable training of safer and less toxic models.
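To make "forgetting by group" concrete, here is a minimal sketch. The formulas are assumptions, not the paper's definitions: relative forgetting is taken as the percentage of the first task's score lost after the second finetuning stage, and the per-group gap is that group's relative forgetting minus the task-level average (analogous to the average-forgetting line drawn in the group plots). The helpers `relative_forgetting`, `group_forgetting_gaps`, and `forgetting_spread` are hypothetical names.

```python
from typing import Dict, Tuple


def relative_forgetting(score_before: float, score_after: float) -> float:
    """Percent of the first task's score lost after finetuning on the second task.

    ASSUMPTION: measured relative to the score right after the first stage;
    the paper's exact formula may differ.
    """
    if score_before <= 0:
        raise ValueError("score_before must be positive")
    return 100.0 * (score_before - score_after) / score_before


def group_forgetting_gaps(
    group_scores: Dict[str, Tuple[float, float]],
    overall_scores: Tuple[float, float],
) -> Dict[str, float]:
    """Hypothetical per-group gap: group relative forgetting minus overall.

    group_scores maps a demographic group to (score_before, score_after).
    A positive gap means the group is forgotten more than the task average.
    """
    overall = relative_forgetting(*overall_scores)
    return {
        group: relative_forgetting(before, after) - overall
        for group, (before, after) in group_scores.items()
    }


def forgetting_spread(gaps: Dict[str, float]) -> float:
    """Hypothetical scalar summary: most-forgotten minus least-forgotten group."""
    return max(gaps.values()) - min(gaps.values())


if __name__ == "__main__":
    gaps = group_forgetting_gaps(
        {"group_a": (0.8, 0.6), "group_b": (0.5, 0.5)},  # toy scores, not paper data
        overall_scores=(0.8, 0.7),
    )
    print({g: round(x, 2) for g, x in gaps.items()})  # prints {'group_a': 12.5, 'group_b': -12.5}
```

A gap-based summary like this is zero when every group loses the same share of its score, so any nonzero spread signals uneven (biased) forgetting under these assumptions.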

Paper Structure

This paper contains 36 sections, 6 equations, 14 figures, 3 tables.

Figures (14)

  • Figure 1: Task ordering experimental set up. Finetuning on a safety task first and then a capability task, and vice versa.
  • Figure 2: Relative Forgetting (%) of the capability task and the safety task for each pair of tasks. Overall, there is more forgetting of the safety task (orange) than forgetting of the capability task (blue).
  • Figure 3: Forgetting by group on ToxiGenQA (a, b, c) and BBQ (d, e, f) after subsequent finetuning on a capability task (ARC, CQA, CQA2). All experiments use learning rate $1\times10^{-5}$ and batch size 16. The blue dotted vertical line marks the average forgetting on the safety task (TQA or BBQ).
  • Figure 4: (a) Minima curvature (i.e., the approximate spectral radius of the Hessian) obtained after training on the first task. (b) Mean downstream forgetting as a function of first-task curvature. Curvature explains a significant proportion of downstream forgetting. Error bars show the standard deviation over three training runs.
  • Figure 5: Forgetting when varying the first-task learning rate for different safety task$\rightarrow$capability task sequences.
  • ...and 9 more figures
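Figure 4 uses the approximate spectral radius of the Hessian as a curvature measure. A common way to estimate it without forming the Hessian is power iteration on Hessian-vector products. The sketch below assumes that approach (the paper's exact estimator is not described here) and, to stay dependency-free, approximates each Hessian-vector product by central differences of a user-supplied gradient function; with a real model one would use an autodiff Hessian-vector product instead. `finite_difference_hvp` and `spectral_radius` are hypothetical helpers.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def _dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def finite_difference_hvp(
    grad_fn: Callable[[Vector], Vector], w: Vector, eps: float = 1e-3
) -> Callable[[Vector], Vector]:
    """Approximate v -> H(w) v as (grad(w + eps*v) - grad(w - eps*v)) / (2*eps)."""
    def hvp(v: Vector) -> Vector:
        g_plus = grad_fn([wi + eps * vi for wi, vi in zip(w, v)])
        g_minus = grad_fn([wi - eps * vi for wi, vi in zip(w, v)])
        return [(gp - gm) / (2.0 * eps) for gp, gm in zip(g_plus, g_minus)]
    return hvp


def spectral_radius(
    hvp: Callable[[Vector], Vector],
    dim: int,
    iters: int = 200,
    tol: float = 1e-10,
    seed: int = 0,
) -> float:
    """Power iteration: estimate the largest-magnitude eigenvalue of H."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(_dot(v, v))
    v = [x / norm for x in v]
    eig = 0.0
    for _ in range(iters):
        hv = hvp(v)
        new_eig = _dot(v, hv)  # Rayleigh quotient; v has unit norm
        hv_norm = math.sqrt(_dot(hv, hv))
        if hv_norm == 0.0:
            return 0.0  # v lies in the null space of H
        v = [x / hv_norm for x in hv]
        converged = abs(new_eig - eig) <= tol * max(1.0, abs(new_eig))
        eig = new_eig
        if converged:
            break
    return abs(eig)


def toy_grad(w: Vector) -> Vector:
    # Gradient of f(w) = 1.5*w0^2 + w0*w1 + 1.5*w1^2, whose Hessian is
    # [[3, 1], [1, 3]] with eigenvalues 4 and 2.
    return [3 * w[0] + w[1], w[0] + 3 * w[1]]


if __name__ == "__main__":
    hvp = finite_difference_hvp(toy_grad, w=[0.0, 0.0])
    print(round(spectral_radius(hvp, dim=2), 6))  # prints 4.0
```

Power iteration only needs repeated Hessian-vector products, so its memory cost stays linear in the number of parameters, which is what makes this kind of curvature estimate practical for large models.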