What Will My Model Forget? Forecasting Forgotten Examples in Language Model Refinement

Xisen Jin, Xiang Ren

TL;DR

The work tackles forgetting during language model refinement and introduces forecasting of forgotten upstream examples to guide efficient replay. It uncovers logit-change transfer as a mechanism by which updating on a new example can flip the predictions of upstream pretraining examples, and proposes two forecasting paradigms: a partially interpretable logit-change-based model and a black-box representation-based model. The forecasts enable targeted replay of upstream examples predicted to be forgotten. Representation-based forecasting performs robustly across models and tasks, while logit-based forecasting excels on certain architectures such as BART0. The study demonstrates practical gains in continual refinement and improved controllability, and shows potential for generalization to multi-step error fixing and out-of-domain settings. It also discusses computational efficiency. Forecasting can be far cheaper than computing ground-truth forgetting, which requires updating the model and re-running inference over the upstream data, and this makes the approach scalable to real-world model updates.
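
The sketch below gives a rough illustration of how forecasts could drive replay. It ranks upstream examples by a forecasted forgetting score and replays the top k alongside the corrected online examples. Uniform random replay, the baseline the abstract contrasts against, is included for comparison. The function names (`select_replay`, `random_replay`, `build_update_batch`), the dictionary-of-scores interface, and the `(input, target)` tuple format are all hypothetical. The scores are assumed to come from some trained forecasting model.

```python
import random
from typing import Dict, List, Sequence, Tuple

Example = Tuple[str, str]  # (input text, target text); a hypothetical format


def select_replay(forecast_scores: Dict[str, float], k: int) -> List[str]:
    """Return ids of the k upstream examples most likely to be forgotten.

    `forecast_scores` maps an upstream example id to a forecasted forgetting
    score, assumed to come from a trained forecaster.
    """
    ranked = sorted(forecast_scores, key=forecast_scores.__getitem__, reverse=True)
    return ranked[:k]


def random_replay(upstream_ids: Sequence[str], k: int, seed: int = 0) -> List[str]:
    """Baseline: sample k upstream ids uniformly at random, without replacement."""
    rng = random.Random(seed)
    return rng.sample(list(upstream_ids), k)


def build_update_batch(
    online_examples: Sequence[Example],
    replay_ids: Sequence[str],
    upstream_pool: Dict[str, Example],
) -> List[Example]:
    """Mix the corrected online examples with the selected upstream replays."""
    return list(online_examples) + [upstream_pool[i] for i in replay_ids]


if __name__ == "__main__":
    scores = {"u1": 0.9, "u2": 0.1, "u3": 0.7, "u4": 0.4}
    print(select_replay(scores, k=2))  # ['u1', 'u3']
```

Given fixed scores, selection by forecast is deterministic, while `random_replay` varies with the seed. This is one way to read the abstract's point about variance and controllability. The replay budget `k` is left as a free parameter here.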

Abstract

Language models deployed in the wild make errors. However, simply updating the model with the corrected error instances causes catastrophic forgetting: the updated model makes errors on instances learned during the instruction tuning or upstream training phase. Randomly replaying upstream data yields unsatisfactory performance and often comes with high variance and poor controllability. To this end, we try to forecast upstream examples that will be forgotten due to a model update, improving the controllability of the replay process and its interpretability. We train forecasting models given a collection of online learned examples and the corresponding forgotten upstream pretraining examples. We propose a partially interpretable forecasting model based on the observation that changes in the pre-softmax logit scores of pretraining examples resemble those of online learned examples. This model performs decently on BART but fails on T5 models. We further show that a black-box classifier based on inner products of example representations achieves better forecasting performance across a series of setups. Finally, we show that replaying examples forecasted to be forgotten reduces forgetting of upstream pretraining examples, demonstrating the practical utility of forecasting example forgetting.
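
The representation-based forecaster in the abstract scores a pair (online example, upstream example) using the inner product of their representations. Below is a minimal sketch. It assumes the representations are already-computed vectors, and that the inner product plus a learned scalar `bias` passes through a sigmoid to give a forgetting probability. The encoder, the bias, the 0.5 threshold, and all function names are hypothetical stand-ins. Training is omitted; one natural choice would be binary cross-entropy on observed forgetting labels.

```python
import math
from typing import List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Inner product of two equal-length representation vectors."""
    if len(u) != len(v):
        raise ValueError("representations must have the same dimension")
    return sum(a * b for a, b in zip(u, v))


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def forgetting_probability(
    h_online: Sequence[float], h_upstream: Sequence[float], bias: float = 0.0
) -> float:
    """Forecasted probability that updating on the online example makes the
    model forget the upstream example (hypothetical scoring rule)."""
    return sigmoid(dot(h_online, h_upstream) + bias)


def forecast_forgotten(
    h_online: Sequence[float],
    upstream_reps: Sequence[Sequence[float]],
    bias: float = 0.0,
    threshold: float = 0.5,
) -> List[int]:
    """Indices of upstream examples whose forecasted probability exceeds threshold."""
    return [
        j
        for j, h_j in enumerate(upstream_reps)
        if forgetting_probability(h_online, h_j, bias) > threshold
    ]


if __name__ == "__main__":
    h_online = [1.0, 0.0]
    upstream = [[2.0, 0.0], [-2.0, 0.0], [0.0, 3.0]]
    print(forecast_forgotten(h_online, upstream, bias=-0.5))  # [0]
```

In this sketch the score depends only on representations available before the update. Forecasting over many upstream examples therefore avoids performing the update and re-running inference, which is the source of the efficiency gain mentioned in the TL;DR.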

Paper Structure (23 sections, 4 equations, 4 figures, 12 tables, 4 algorithms)

Figures (4)

  • Figure 1: Intriguing patterns of example forgetting while correcting prediction errors in FLAN-T5. Fixing the error on a question related to public relations flips the prediction on an example from the paraphrase detection task. It is hard to interpret this forgetting solely from a human understanding of textual (dis)similarity, or of conflicting skills required to answer the question.
  • Figure 2: (a) Transfer of logit changes of first output tokens on an upstream pretraining example $\langle x_j, y_j \rangle$ when fixing the prediction error of an online learning example $\langle x_i, y_i \rangle$ (see Figure 1 for the full texts of the examples). After fixing the error, the logit scores of the tokens "not" and "duplicates" in $\langle x_i,y_i \rangle$ change significantly, even though their token probabilities after normalization are both close to 0. The logit change has no effect on the prediction of $\langle x_i,y_i \rangle$. However, the prediction of the upstream pretraining example $\langle x_j,y_j \rangle$ flips as the logit change partially transfers to $\langle x_j,y_j \rangle$. (b) Logit-based forecasting infers the transfer of logit changes from a learned similarity measure between the two examples. (c) Representation-based forecasting directly predicts the binary forgetting label from a learned similarity measure.
  • Figure 3: F1, Precision, and Recall of representation-based (Rep), threshold-based (Thres), and trainable logit-based forecasting models, averaged up to a given time step (in $x$-axis) when continually refining the LM. For all forecasting methods, recall drops over time (as more examples are forgotten), while precision remains stable. Representation-based forecasting achieves the best F1 and precision at the end of the sequence.
  • Figure 4: F1, Precision, and Recall of representation-based forecasting models averaged up to a given time step (in $x$-axis) when continually refining the LM under different learning rates.
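
Figure 2(a)-(b) describe forgetting through logit-change transfer. The change in the online example's first-token logits partially carries over to the upstream example, scaled by a similarity between the two examples. The sketch below is a simplified illustration under stated assumptions. It adds a similarity-scaled copy of the online logit change to the upstream example's pre-update logits, then forecasts forgetting when the gold first token stops being the argmax. Here `similarity` is a plain scalar input and the three-token vocabulary is a toy example. In the logit-based forecaster this similarity is learned, and the function names are hypothetical.

```python
from typing import List, Sequence


def argmax(xs: Sequence[float]) -> int:
    """Index of the largest entry (the first one on ties)."""
    return max(range(len(xs)), key=lambda k: xs[k])


def forecast_logits(
    upstream_before: Sequence[float],
    online_before: Sequence[float],
    online_after: Sequence[float],
    similarity: float,
) -> List[float]:
    """Forecast the upstream example's post-update first-token logits by
    adding a similarity-scaled copy of the online example's logit change."""
    if not (len(upstream_before) == len(online_before) == len(online_after)):
        raise ValueError("logit vectors must cover the same vocabulary")
    return [
        f_j + similarity * (after - before)
        for f_j, before, after in zip(upstream_before, online_before, online_after)
    ]


def forecast_is_forgotten(
    upstream_before: Sequence[float],
    gold_token: int,
    online_before: Sequence[float],
    online_after: Sequence[float],
    similarity: float,
) -> bool:
    """True if the upstream example is predicted correctly before the update
    but its forecasted logits no longer rank the gold token first."""
    if argmax(upstream_before) != gold_token:
        return False  # not learned before the update, so it cannot be forgotten
    predicted = forecast_logits(upstream_before, online_before, online_after, similarity)
    return argmax(predicted) != gold_token


if __name__ == "__main__":
    # Toy vocabulary: index 0 = gold answer token, 1 and 2 = competing tokens.
    online_before = [10.0, -5.0, 0.0]
    online_after = [10.0, 0.0, 0.0]  # token 1's logit rises; argmax stays 0
    upstream_before = [2.0, 1.0, 0.0]
    print(forecast_is_forgotten(upstream_before, 0, online_before, online_after, 0.5))  # True
    print(forecast_is_forgotten(upstream_before, 0, online_before, online_after, 0.1))  # False
```

The toy run mirrors the Figure 2 caption. The online example's own prediction is unchanged, because token 1's probability stays near 0 before and after the update. With a large enough similarity, however, the transferred change flips the upstream prediction.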