Joint Flashback Adaptation for Forgetting-Resistant Instruction Tuning

Yukun Zhao, Lingyong Yan, Zhenyang Li, Shuaiqiang Wang, Zhumin Chen, Zhaochun Ren, Dawei Yin

TL;DR

This work tackles catastrophic forgetting during continual instruction-tuning of large language models by introducing Joint Flashback Adaptation (JFA). JFA leverages a small set of flashback prompts from old tasks to constrain deviations and employs latent-task learning to share knowledge between new tasks and flashbacks, avoiding replay data and task differentiation. The framework combines a flashback divergence objective with joint task learning, where latent tasks are retrieved via KNN and integrated through LoRA-based weight increments, all optimized with gradient projection. Experiments on Vicuna-13B and Llama-3.1-8B across 1000+ tasks show improved generalization on new tasks and reduced forgetting on old tasks, with competitive performance relative to replay-based baselines on several benchmarks. The method demonstrates practical advantages for continual instruction tuning, highlighting the value of selective old-task prompts and latent task interpolation for robust knowledge transfer.
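The summary above says the objectives are optimized with gradient projection but does not state the projection rule. As a rough illustration only, the sketch below uses a PCGrad-style rule, which is an assumption and not necessarily what the paper does: when the new-task gradient conflicts with the flashback gradient, the conflicting component is removed. The function name `project_conflicting` and the toy vectors are hypothetical.

```python
import numpy as np


def project_conflicting(g_new, g_flash):
    """Return g_new with any component that opposes g_flash removed.

    Assumption: a PCGrad-style rule. The source only says "gradient
    projection"; the exact rule used in the paper is not given here.
    """
    dot = float(g_new @ g_flash)
    if dot >= 0.0:
        # No conflict: the new-task update does not oppose the flashback update.
        return g_new.copy()
    # Conflict: subtract the projection of g_new onto g_flash.
    return g_new - dot / float(g_flash @ g_flash) * g_flash


# Toy 2-D gradients (hypothetical values, for illustration only).
g_flash = np.array([1.0, 0.0])
g_conflict = np.array([-1.0, 1.0])
g_aligned = np.array([1.0, 1.0])

g_proj = project_conflicting(g_conflict, g_flash)
g_kept = project_conflicting(g_aligned, g_flash)
```

After projection, the update for the new task is orthogonal to the flashback gradient, so to first order it no longer increases the flashback loss, while a non-conflicting gradient passes through unchanged.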

Abstract

Large language models have achieved remarkable success in various tasks. However, it is challenging for them to learn new tasks incrementally due to catastrophic forgetting. Existing approaches rely on experience replay, optimization constraints, or task differentiation, which encounter strict limitations in real-world scenarios. To address these issues, we propose Joint Flashback Adaptation. We first introduce flashbacks -- a limited number of prompts from old tasks -- when adapting to new tasks and constrain the deviations of the model outputs compared to the original one. We then interpolate latent tasks between flashbacks and new tasks to enable jointly learning relevant latent tasks, new tasks, and flashbacks, alleviating data sparsity in flashbacks and facilitating knowledge sharing for smooth adaptation. Our method requires only a limited number of flashbacks without access to the replay data and is task-agnostic. We conduct extensive experiments on state-of-the-art large language models across 1000+ instruction-following tasks, arithmetic reasoning tasks, and general reasoning tasks. The results demonstrate the superior performance of our method in improving generalization on new tasks and reducing forgetting in old tasks.
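To make the two ideas in the abstract concrete, here is a minimal NumPy sketch under stated assumptions. It treats the constraint on output deviations as a KL divergence between a frozen reference model and the adapted model on flashback prompts, and it treats latent-task interpolation as a simple linear mix of two representations. Both choices, along with the names `flashback_divergence` and `interpolate_latent` and the random toy logits, are illustrative assumptions rather than the paper's exact formulation.

```python
import numpy as np


def log_softmax(logits):
    # Numerically stable log-softmax over the last (vocabulary) axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def flashback_divergence(ref_logits, new_logits):
    """Mean KL(reference || adapted) over flashback positions.

    Assumption: KL divergence is used as the deviation measure.
    ref_logits: outputs of the frozen original model on flashback prompts.
    new_logits: outputs of the model being adapted, same shape.
    """
    ref_logp = log_softmax(ref_logits)
    new_logp = log_softmax(new_logits)
    kl = (np.exp(ref_logp) * (ref_logp - new_logp)).sum(axis=-1)
    return float(kl.mean())


def interpolate_latent(flashback_repr, new_task_repr, lam):
    # Hypothetical linear interpolation between a flashback representation
    # and a new-task representation; lam in [0, 1] sets the mix.
    return lam * flashback_repr + (1.0 - lam) * new_task_repr


rng = np.random.default_rng(0)
ref = rng.normal(size=(4, 10))  # 4 flashback positions, toy vocabulary of 10
drifted = ref + rng.normal(scale=0.5, size=ref.shape)

loss_same = flashback_divergence(ref, ref)       # no deviation from the reference
loss_drift = flashback_divergence(ref, drifted)  # adapted model has drifted

flash_vec = np.array([1.0, 0.0, 0.0])
new_vec = np.array([0.0, 1.0, 0.0])
latent = interpolate_latent(flash_vec, new_vec, lam=0.25)
```

The divergence term needs only flashback prompts, not their targets, because the reference model supplies the distribution to match; this is consistent with the claim that no replay data is required.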

Paper Structure

This paper contains 27 sections, 8 equations, 7 figures, 2 tables, 1 algorithm.

Figures (7)

  • Figure 1: An ideal scenario of incrementally learning new tasks. Given an off-the-shelf model that is aligned on old tasks, we adapt the model to new tasks while preserving its capabilities on old tasks. The process does not rely on experience replay or task differentiation.
  • Figure 2: The overview of Joint Flashback Adaptation. The inputs contain prompts and targets from new tasks, and a few flashbacks (prompts only) from old tasks. The reference model is maintained to track the deviations. The black solid lines represent the forward pass of actual samples, while the orange solid lines represent the forward pass for joint latent task learning. The dashed gray lines indicate the backpropagation of gradients.
  • Figure 3: The comparison of joint flashback adaptation with and without joint task learning (JTL).
  • Figure 4: The performance of our method using different numbers of flashbacks per dataset. The optimal choice is marked with the vertical dashed line.
  • Figure 5: The compared performance using different $\alpha$. The optimal choice is marked with the vertical dashed line.
  • ...and 2 more figures