
An Efficient Rehearsal Scheme for Catastrophic Forgetting Mitigation during Multi-stage Fine-tuning

Andrew Bai, Chih-Kuan Yeh, Cho-Jui Hsieh, Ankur Taly

TL;DR

Catastrophic forgetting is a central challenge in sequential fine-tuning of foundational models. The authors introduce mix-cd, a compute-constrained rehearsal scheme that prioritizes collateral-damage samples (those correctly predicted by the base model but forgotten after fine-tuning) by estimating their density without extra model inferences. They define a per-bin collateral damage rate $\boldsymbol{\alpha}_k$ and update it across iterations using previously mixed-in data, enabling targeted rehearsal without additional forward passes. Partitions based on prior loss, auxiliary information, and their combinations guide sampling toward regions most prone to forgetting. Across three tasks (NLI, QA, and multilingual translation), mix-cd consistently outperforms baselines on the Pareto frontier, demonstrating a practical, lightweight approach to forgetting mitigation with broad applicability to compute-constrained continual learning.
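The per-bin estimate described above can be pictured with a small Python sketch. It is illustrative only and rests on assumptions this summary does not state: prior-task samples are pre-grouped into bins (for example by quantiles of the base model's cached loss), the base model's correctness on each prior sample is cached from the earlier stage, $\alpha_k$ is tracked as smoothed running counts over the mixed-in samples that training already runs forward passes on, and bins are drawn with probability proportional to bin size times $\alpha_k$. The class `CollateralDamageSampler`, its methods, and the `prior_rate`/`smoothing` parameters are hypothetical names, not the authors' code.

```python
import random
from collections import defaultdict


class CollateralDamageSampler:
    """Hypothetical sketch: per-bin collateral damage rates estimated only from
    prior-task samples that were already mixed into training batches."""

    def __init__(self, bins, prior_rate=0.5, smoothing=1.0):
        # bins: dict mapping bin id -> list of prior-task sample ids
        # (assumed to be built beforehand, e.g. from cached base-model losses).
        self.bins = bins
        self.prior_rate = prior_rate
        self.smoothing = smoothing
        # Running counts per bin: base-correct samples seen, and how many of
        # those the current model got wrong (collateral damage).
        self.observed = defaultdict(float)
        self.damaged = defaultdict(float)

    def rate(self, k):
        # Smoothed estimate of alpha_k. Pseudo-counts keep a bin with no
        # observations at prior_rate instead of 0 or an undefined ratio.
        return (self.damaged[k] + self.smoothing * self.prior_rate) / (
            self.observed[k] + self.smoothing
        )

    def update(self, k, base_correct, current_correct):
        # Reuses the outcome of a forward pass training already performed on a
        # mixed-in sample; only base-correct samples can be collateral damage.
        if base_correct:
            self.observed[k] += 1
            if not current_correct:
                self.damaged[k] += 1

    def bin_weights(self):
        # Expected collateral damage mass per bin: bin size * alpha_k, normalized.
        raw = {k: len(ids) * self.rate(k) for k, ids in self.bins.items()}
        total = sum(raw.values())
        return {k: w / total for k, w in raw.items()}

    def sample(self, n, rng=random):
        # Pick bins by weight, then a sample uniformly within each chosen bin.
        weights = self.bin_weights()
        keys = list(weights)
        chosen = rng.choices(keys, weights=[weights[k] for k in keys], k=n)
        return [(k, rng.choice(self.bins[k])) for k in chosen]


bins = {0: ["a", "b"], 1: ["c", "d", "e", "f"]}
sampler = CollateralDamageSampler(bins)
sampler.update(0, base_correct=True, current_correct=False)
sampler.update(1, base_correct=True, current_correct=True)
print(sampler.bin_weights())
print(sampler.sample(4, random.Random(0)))
```

With one observed collateral damage sample in bin 0 and none in bin 1, the smoothed rates are 0.75 and 0.25, so the two-sample bin 0 receives weight 0.6 against 0.4 for the four-sample bin 1; rehearsal is thus tilted toward the bin where forgetting has been observed, even though it holds fewer samples.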

Abstract

Incrementally fine-tuning foundational models on new tasks or domains is now the de facto approach in NLP. A known pitfall of this approach is the catastrophic forgetting of prior knowledge that happens during fine-tuning. A common way to alleviate such forgetting is to rehearse samples from prior tasks during fine-tuning. Several existing works assume a fixed memory buffer to store prior task examples, while relying on inferences (forward passes) with the model at hand to choose examples for rehearsal from the buffer. However, given the increasing computational cost of model inference and the decreasing cost of data storage, we focus on rehearsing samples under a fixed computational budget instead of a fixed memory budget. We propose a sampling scheme, mix-cd, that prioritizes rehearsal of "collateral damage" samples, which are samples predicted correctly by the prior model but forgotten by the incrementally tuned one. The crux of our scheme is a procedure to efficiently estimate the density of collateral damage samples without incurring additional model inferences. Our approach is computationally efficient, easy to implement, and outperforms several leading continual learning methods in compute-constrained settings. All the code will be publicly available at https://github.com/jybai/mix-cd-rehearsal.
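For contrast, the quantity mix-cd estimates can be computed exactly if one is willing to run the fine-tuned model on all prior data, which is precisely the extra inference cost the method avoids. The sketch below labels a sample as collateral damage when the prior model's prediction matches the label and the tuned model's does not, and reports per-bin rates normalized by the number of prior-correct samples in each bin. That normalization and the function names `collateral_damage` and `oracle_bin_rates` are assumptions made for illustration, not definitions taken from the paper.

```python
from collections import Counter


def collateral_damage(label, prior_pred, tuned_pred):
    # Correct before fine-tuning, wrong after: the sample was "forgotten".
    return prior_pred == label and tuned_pred != label


def oracle_bin_rates(bin_ids, labels, prior_preds, tuned_preds):
    """Exact per-bin collateral damage rates. Requires tuned-model predictions
    on every prior sample, i.e. the full inference pass mix-cd avoids."""
    eligible = Counter()  # prior-correct samples per bin
    damaged = Counter()   # of those, samples the tuned model now gets wrong
    for k, y, p, t in zip(bin_ids, labels, prior_preds, tuned_preds):
        if p == y:
            eligible[k] += 1
            damaged[k] += int(collateral_damage(y, p, t))
    return {k: damaged[k] / eligible[k] for k in eligible}


rates = oracle_bin_rates(
    bin_ids=[0, 0, 0, 1, 1],
    labels=[1, 1, 0, 1, 0],
    prior_preds=[1, 1, 1, 1, 0],
    tuned_preds=[0, 1, 1, 1, 0],
)
print(rates)
```

In this toy input, bin 0 has two prior-correct samples, one of which the tuned model now misclassifies (rate 0.5), while bin 1 has two prior-correct samples that both survive fine-tuning (rate 0.0). The third sample in bin 0 is excluded because the prior model already got it wrong, so it cannot count as forgotten.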

Paper Structure (48 sections, 7 equations, 9 figures, 1 algorithm)


Figures (9)

  • Figure 1: Examples of collateral damage in prior language translations after fine-tuning on Danish-to-English.
  • Figure 2: Preliminary observations suggest that while random rehearsal of prior data helps mitigate collateral damage, upweighting collateral damage samples in the prior data distribution benefits the joint performance on both tasks even more.
  • Figure 3: Pareto frontiers of prior and fine-tune performance. Curves closer to the top right are preferable.
  • Figure 4: Proportion of collateral damage among rehearsed samples for random uniform sampling versus mix-cd across different mix ratios. mix-cd consistently selects at least twice as many collateral damage samples for rehearsal as random uniform sampling, which explains its superior performance.
  • Figure 5: Ablation study on different partitions for the data distribution. Partitions with higher KL divergence in collateral damage ratios between bins (e.g. loss and answerable partitions) provide better signals for prioritizing collateral damage samples.
  • ...and 4 more figures

Theorems & Definitions (1)

  • Remark