
Demystifying Language Model Forgetting with Low-rank Example Associations

Xisen Jin, Xiang Ren

TL;DR

This work shows that forgetting in LLM fine-tuning can be captured by low-rank associations between learned tasks and upstream examples. By modeling forgetting as an M×N matrix Z and applying matrix completion (MF/KNN) to predict which upstream examples are most forgotten, the authors enable targeted replay that mitigates forgetting at a lower cost than running inference over all upstream data. Empirically, low-rank approximations explain forgetting well across several model families and sizes, and MF-based forgetting predictions coupled with replay yield statistically significant reductions in forgetting on held-out data. The approach offers a scalable, interpretable, and practical path toward continual-learning-style mitigation of forgetting in large language models.
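
The TL;DR frames forgetting prediction as completing rows of Z. The following is a minimal sketch of the KNN variant, assuming that for a new task forgetting is measured on only a few probe upstream examples, that task similarity is Euclidean distance on those probes, and that the prediction is the mean of the k nearest seen-task rows. `knn_complete_row`, `most_forgotten`, and the toy matrix are hypothetical; the paper's KNN setup (and its MF alternative, which instead learns low-dimensional task and example factors) may differ.

```python
import math


def knn_complete_row(Z_train, observed, k=2):
    """Predict the forgetting row of an unseen task from a few measured entries.

    Z_train:  M rows (seen tasks) x N columns (upstream examples) of forgetting.
    observed: {upstream_index: forgetting} measured for the new task (assumed probes).
    Returns N predictions; measured entries are kept as-is.
    """
    probe = sorted(observed)

    def distance(row):
        # Euclidean distance restricted to the probed upstream examples
        return math.sqrt(sum((row[j] - observed[j]) ** 2 for j in probe))

    neighbors = sorted(Z_train, key=distance)[:k]
    n = len(Z_train[0])
    # Average the k most similar seen-task rows, column by column
    pred = [sum(row[j] for row in neighbors) / len(neighbors) for j in range(n)]
    for j, value in observed.items():
        pred[j] = value
    return pred


def most_forgotten(pred, top):
    """Indices of the `top` upstream examples with the largest predicted forgetting."""
    return sorted(range(len(pred)), key=lambda j: pred[j], reverse=True)[:top]


# Hypothetical Z for 4 seen tasks x 5 upstream examples (e.g. log-perplexity increase).
Z_train = [
    [0.9, 0.8, 0.1, 0.0, 0.7],
    [1.0, 0.9, 0.2, 0.1, 0.8],
    [0.0, 0.1, 0.9, 1.0, 0.1],
    [0.1, 0.0, 0.8, 0.9, 0.2],
]
pred = knn_complete_row(Z_train, observed={0: 0.95, 2: 0.15}, k=2)
print(most_forgotten(pred, top=2))  # [0, 1]
```

KNN needs no training step: each new task costs only the probe measurements plus a nearest-neighbor lookup over the M seen tasks, and the top predicted indices become replay candidates.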

Abstract

Large language models (LLMs) suffer from forgetting of upstream knowledge when fine-tuned. Despite efforts to mitigate forgetting, few have investigated how the forgetting of upstream examples depends on the newly learned tasks. Insights into such dependencies enable efficient and targeted mitigation of forgetting. In this paper, we empirically analyze forgetting that occurs in $N$ upstream examples of language modeling or instruction-tuning after fine-tuning LLMs on one of $M$ new tasks, visualized in $M\times N$ matrices. We show that the matrices are often well-approximated by low-rank matrices, indicating the dominance of simple associations between the learned tasks and forgotten upstream examples. Leveraging this analysis, we predict forgetting of upstream examples when fine-tuning LLMs on unseen tasks via matrix completion over the empirical associations. This enables fast identification of the most forgotten examples without expensive inference on the entire upstream data. Despite its simplicity, the approach outperforms prior approaches that use LMs to learn semantic relationships between learned tasks and upstream examples. We demonstrate the practical utility of our analysis by showing statistically significant reductions in forgetting when we upweight the predicted examples for replay during fine-tuning. Code, data, and statistics collected: https://github.com/AuCson/low-rank-forgetting
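
The abstract's last step, upweighting predicted examples for replay, can be sketched as weighted sampling. This assumes a softmax with a temperature over predicted forgetting and sampling with replacement; `replay_weights`, `sample_replay_batch`, the temperature parameter, and the toy data are hypothetical, and the paper's exact weighting scheme may differ.

```python
import math
import random


def replay_weights(pred_forgetting, temperature=1.0):
    """Softmax over predicted forgetting: more-forgotten examples get more weight.

    A large temperature approaches uniform (random) replay; a small one
    concentrates replay on the examples predicted to be most forgotten.
    """
    scaled = [f / temperature for f in pred_forgetting]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_replay_batch(upstream, pred_forgetting, batch_size, temperature=1.0, seed=0):
    """Draw upstream examples (with replacement) to mix into fine-tuning batches."""
    rng = random.Random(seed)
    weights = replay_weights(pred_forgetting, temperature)
    return rng.choices(upstream, weights=weights, k=batch_size)


upstream = ["doc_a", "doc_b", "doc_c", "doc_d"]
pred = [0.05, 0.85, 0.15, 0.95]  # hypothetical predicted log-perplexity increase
batch = sample_replay_batch(upstream, pred, batch_size=8, temperature=0.1)
print(batch)
```

The temperature interpolates between random replay (the baseline) and replaying only the top predicted examples; feeding in actual rather than predicted forgetting would give an oracle version of the same scheme.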

Paper Structure (25 sections, 14 figures, 16 tables)

Figures (14)

  • Figure 1: The problem setup for analyzing the associations between learned tasks and forgotten upstream examples as we fine-tune LLMs on one of the unseen new tasks. Over a total of $N$ upstream examples and $M$ unseen tasks, we measure and record forgetting (in red) in an $M\times N$ matrix and attempt to fit the associations with low-rank approximations. Better low-rank approximations indicate simpler associations between learned tasks and forgotten upstream examples.
  • Figure 2: An example of a visualized association matrix of forgetting $Z\in \mathbb{R}^{M\times N}$ between $M=85$ learned tasks and $N=141{,}876$ upstream examples (from Dolma) on OLMo-7B. Each pixel $z_{ij}$ indicates the forgetting (increase in log perplexity) that occurs on an upstream example $x_j$ ($x$-axis) after fine-tuning the model on a task $T_i$ ($y$-axis). We annotate the domains (e.g., reddit) of upstream examples on the $x$-axis and the category of each learned task (e.g., FLAN/QA) on the $y$-axis. Visualizations for more models and setups are included in two additional figures in the Appendix.
  • Figure 3: (a) $R^2$ or F1 of the low-rank approximations as we progressively increase the rank of the reconstruction. Forgetting is measured as log-perplexity increase or exact match (EM) drop. In (b) and (c), we compare the $R^2$ of approximations at a given rank $r$ across models of different types and sizes over a fixed set of $M=19$ tasks from Tulu and Dolly. For reference, (d) reports average upstream example forgetting; more statistics are given in a table in the Appendix.
  • Figure 4: The training and testing setup for predicting example forgetting via association matrix completion, and its integration into example replay methods to mitigate forgetting.
  • Figure 5: Log perplexity ($\downarrow$) or Token F1 ($\uparrow$) over upstream data under different replay example selection strategies. The solid horizontal lines indicate the log perplexity before fine-tuning (i.e., no forgetting). The dashed lines show the results of upweighting upstream examples according to their actual forgetting after fine-tuning without replay. * and ** indicate statistically significant improvement ($p<0.05$ or $p<0.005$) over replaying random examples in paired $t$-tests across all fine-tuning tasks.
  • ...and 9 more figures
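
Figure 3 tracks $R^2$ as the rank of the reconstruction grows. The following is a minimal, dependency-free sketch of that measurement. It assumes $R^2$ is pooled over all entries of $Z$ and that the rank-$r$ approximation is the truncated SVD, computed here by power iteration with deflation rather than a linear-algebra library. The function names and the toy 4×5 matrix are hypothetical, and the paper's exact $R^2$ definition may differ.

```python
import math
import random


def leading_singular_triplet(Z, iters=200, seed=0):
    """Leading (sigma, u, v) of Z via power iteration on Z^T Z."""
    rng = random.Random(seed)
    m, n = len(Z), len(Z[0])
    v = [rng.random() + 0.1 for _ in range(n)]  # positive random start vector
    for _ in range(iters):
        u = [sum(Z[i][j] * v[j] for j in range(n)) for i in range(m)]  # u = Z v
        v = [sum(Z[i][j] * u[i] for i in range(m)) for j in range(n)]  # v = Z^T u
        norm = math.sqrt(sum(x * x for x in v))
        if norm == 0.0:  # residual is (numerically) zero
            return 0.0, [0.0] * m, [0.0] * n
        v = [x / norm for x in v]
    u = [sum(Z[i][j] * v[j] for j in range(n)) for i in range(m)]
    sigma = math.sqrt(sum(x * x for x in u))
    u = [x / sigma for x in u] if sigma > 0 else u
    return sigma, u, v


def low_rank_approx(Z, rank):
    """Rank-`rank` approximation: peel off singular triplets one at a time (deflation)."""
    m, n = len(Z), len(Z[0])
    residual = [row[:] for row in Z]
    approx = [[0.0] * n for _ in range(m)]
    for k in range(rank):
        sigma, u, v = leading_singular_triplet(residual, seed=k)
        for i in range(m):
            for j in range(n):
                approx[i][j] += sigma * u[i] * v[j]
                residual[i][j] -= sigma * u[i] * v[j]
    return approx


def r_squared(Z, Z_hat):
    """Coefficient of determination of Z_hat against Z, pooled over all entries."""
    flat = [z for row in Z for z in row]
    mean = sum(flat) / len(flat)
    ss_tot = sum((z - mean) ** 2 for z in flat)
    ss_res = sum((z - zh) ** 2
                 for row, row_hat in zip(Z, Z_hat)
                 for z, zh in zip(row, row_hat))
    return 1.0 - ss_res / ss_tot


# Hypothetical 4 x 5 "forgetting matrix" built from two task/example factors (rank 2).
a1, b1 = [1, 2, 3, 4], [1, 1, 1, 1, 1]
a2, b2 = [1, -1, 1, -1], [1, -1, 0, 1, -1]
Z = [[a1[i] * b1[j] + a2[i] * b2[j] for j in range(5)] for i in range(4)]
for r in (1, 2):
    print(r, round(r_squared(Z, low_rank_approx(Z, r)), 3))  # prints: 1 0.624, then 2 1.0
```

A real association matrix (e.g. 85 × 141,876) would call for a library SVD; the pure-Python loops only keep the example self-contained. A curve that saturates at small $r$ is what the captions describe as simple task-example associations.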