Unlocking the Power of Function Vectors for Characterizing and Mitigating Catastrophic Forgetting in Continual Instruction Tuning

Gangwei Jiang, Caigao Jiang, Zhaoyi Li, Siqiao Xue, Jun Zhou, Linqi Song, Defu Lian, Ying Wei

TL;DR

This paper investigates catastrophic forgetting (CF) in continual instruction tuning of large language models by introducing the function vector (FV), a compact latent representation of task-specific functions derived from causal attention heads. It demonstrates that CF mainly arises from shifts in the latent function activation $P_M(\theta|x)$ rather than from overwriting the prior task function mappings $P_M(y|x,\theta)$, and proposes an FV-guided training method that stabilizes $\theta_T$ via FV consistency and FV-guided KL losses. Empirical results across multiple benchmarks and models show that FV-guided training significantly mitigates forgetting of general and in-context abilities while preserving plasticity for new tasks, and that FV dynamics correlate strongly with forgetting. The work positions FV as a mechanistic interpretability tool for analyzing and mitigating forgetting, and provides a practical training design that can be integrated with existing continual learning methods.
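
The two quantities named above fit together under the latent-variable reading of the model that the Figure 5 caption also refers to. As a sketch, assuming the latent function $\theta$ is marginalized out (written here as a sum; an integral would apply if $\theta$ is treated as continuous), the model's prediction factorizes as:

$$P_M(y \mid x) = \sum_{\theta} P_M(y \mid x, \theta)\, P_M(\theta \mid x)$$

Under this reading, a drop in $P_M(y \mid x)$ on an old task can come from either factor. The claim summarized above is that the change in $P_M(\theta \mid x)$, i.e., which function gets activated for a given input, is the dominant one.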

Abstract

Catastrophic forgetting (CF) poses a significant challenge in machine learning, where a model forgets previously learned information upon learning new tasks. Despite the advanced capabilities of Large Language Models (LLMs), they continue to face challenges with CF during continual learning. The majority of existing research focuses on analyzing forgetting patterns through a singular training sequence, thereby overlooking the intricate effects that diverse tasks have on model behavior. Our study explores CF across various settings, discovering that model forgetting is influenced by both the specific training tasks and the models themselves. To this end, we interpret forgetting by examining the function vector (FV), a compact representation of functions in LLMs, offering a model-dependent indicator for the occurrence of CF. Through theoretical and empirical analyses, we demonstrated that CF in LLMs primarily stems from biases in function activation rather than the overwriting of task processing functions. Leveraging these insights, we propose a novel function vector guided training methodology, incorporating a regularization technique to stabilize the FV and mitigate forgetting. Empirical tests on four benchmarks confirm the effectiveness of our proposed training method, substantiating our theoretical framework concerning CF and model function dynamics. We plan to make our code publicly accessible in the near future.
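
To make the idea of FV-based regularization concrete, here is a minimal sketch of how an FV consistency term and a KL term might be combined with a task loss. This summary does not give the exact formulation, so everything below is an assumption for illustration: the squared L2 distance as the consistency measure, the direction of the KL term, the weights `alpha` and `beta`, and all function names are hypothetical.

```python
import math


def fv_consistency_loss(current_fv, reference_fv):
    """Squared L2 distance between the current and reference function vectors.

    Assumption: the choice of squared L2 is illustrative, not taken from the paper.
    """
    return sum((c - r) ** 2 for c, r in zip(current_fv, reference_fv))


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def fv_guided_objective(task_loss, current_fv, reference_fv,
                        reference_logits, current_logits,
                        alpha=1.0, beta=1.0):
    """Task loss plus two hypothetical regularizers.

    alpha weights the FV consistency term; beta weights the KL term between a
    reference output distribution and the current model's output distribution.
    """
    consistency = fv_consistency_loss(current_fv, reference_fv)
    kl = kl_divergence(softmax(reference_logits), softmax(current_logits))
    return task_loss + alpha * consistency + beta * kl


# Toy values: the FV has drifted by 1.0 in one coordinate, outputs are unchanged.
total = fv_guided_objective(
    task_loss=0.5,
    current_fv=[1.0, 2.0],
    reference_fv=[1.0, 1.0],
    reference_logits=[0.2, 0.8],
    current_logits=[0.2, 0.8],
)
print(total)
```

In this toy case the KL term is zero because the two output distributions match, so the penalty comes entirely from the FV drift. In a real training loop these terms would be computed on tensors with automatic differentiation; plain Python lists are used here only to keep the sketch self-contained.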

Paper Structure

This paper contains 29 sections, 9 equations, 10 figures, 8 tables, 1 algorithm.

Figures (10)

  • Figure 1: Performance heatmap during continual learning of 2 different sequences on Llama2-7b-chat and Llama3-8b-chat. The numbers above the heatmap indicate the baseline performance of each task: for general testing, the performance of the pre-trained model (e.g., in a-(II), 66.3 is the score of Commonsense on the original Llama2-7b-chat); for trained-task testing, the performance right after completing the current task (e.g., in a-(I), 28.2 is the score of T2 on Llama2-7b-chat after training on the 2nd task). The numbers on the heatmap show performance as a percentage of the baseline (e.g., in the first column of a-(I), 47 indicates a score of 38.6 × 47%). Main conclusions: (1) learning generation tasks (a/c) leads to more forgetting than learning classification tasks (b/d); (2) forgetting may reduce naturally (a-(II)/d-(II)); (3) forgetting is model-dependent (a/b vs. c/d).
  • Figure 2: Shifts in the function vector alongside 0/5-shot performance during tuning. The bar chart (left y-axis) shows the similarity of function vectors to their initial state. The line graph (right y-axis) depicts the model's Rouge-L metric on test data. Main conclusion: there is a significant correlation between performance (line data) and FV similarity (bar data). Correlation plots with more data points are provided in Fig. \ref{fig:app:corr}.
  • Figure 3: Intervention results on the fine-tuned model. '+ Source FV' and '- Target FV' refer to Evidence I and Evidence II, respectively. Main conclusion: intervening with the related function vector mitigates forgetting.
  • Figure 4: The shifts in function vector with 0/5-shot performance with function vector guided training. Main conclusion: FVG prevents the shift in FV (yellow bar) and thus mitigates forgetting (orange line).
  • Figure 5: Illustration of the causal pathway to forgetting. In (a), the pre-trained model is expressed under a latent variable assumption: task $T_0$ establishes a predictive pathway (shown in orange) that aligns well with the task (high probability with $\theta_{T_0}^0$). Panel (b) shows the model after learning a new task $T_1$ without regularization, which necessarily updates the function attention heads, i.e., $P_M(\theta|x)$ (shown as red blocks), producing new function vectors $\theta^1_{T_0}$ and $\theta^1_{T_1}$ that are biased toward $T_1$. These shifts in function vectors lead to a derailed predictive pathway (shown in purple) with erroneous predictions for task $T_0$; in other words, forgetting of $T_0$ occurs. In summary, the modifications in $P_M(\theta|x)$ rather than $P_M(y|x,\theta)$ are the primary driving force behind forgetting.
  • ...and 5 more figures
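
The Figure 2 caption tracks the similarity of function vectors to their initial state during tuning. A minimal sketch of such a measurement follows, assuming cosine similarity as the metric (the caption does not name one) and treating each function vector as a plain list of floats. The function names and toy vectors are hypothetical.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)


def fv_similarity_trace(initial_fv, checkpoint_fvs):
    """Similarity of each checkpoint's function vector to the initial one."""
    return [cosine_similarity(initial_fv, fv) for fv in checkpoint_fvs]


# Toy function vectors: unchanged, partially rotated, and orthogonal.
initial_fv = [1.0, 0.0, 0.0]
checkpoint_fvs = [
    [1.0, 0.0, 0.0],
    [1.0, 1.0, 0.0],
    [0.0, 1.0, 0.0],
]
trace = fv_similarity_trace(initial_fv, checkpoint_fvs)
print([round(s, 3) for s in trace])  # [1.0, 0.707, 0.0]
```

A trace that falls toward zero over successive checkpoints would correspond to the shrinking bars in Figure 2. Relating it to the Rouge-L line would require the per-checkpoint evaluation scores, which are not reproduced here.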