Continual Memorization of Factoids in Language Models

Howard Chen, Jiayi Geng, Adithya Bhaskar, Dan Friedman, Danqi Chen

TL;DR

This work introduces continual memorization, a two-stage fine-tuning framework in which a model must memorize factoids learned in stage one and retain them after stage two with potentially conflicting data. It reveals that forgetting is especially severe when stage two is another factoid dataset and that replay alone cannot fully mitigate this forgetting. The authors propose REMIX, a simple yet effective data-mixing strategy that incorporates random word sequences and generic pretraining data into training stages, yielding substantial retention gains over baselines and revealing that memorized facts are stored in earlier and more diversified layers. Through analytical and probing methods (e.g., Logit Lens), REMIX is shown to alter memorization dynamics, enabling easier recall and manipulation of facts with minimal impact on downstream performance. The results offer practical guidance for preserving knowledge in LMs during continual updating and open avenues for further study on stability and safety of memorized knowledge.

Abstract

As new knowledge rapidly accumulates, language models (LMs) with pretrained knowledge quickly become obsolete. A common approach to updating LMs is fine-tuning them directly on new knowledge. However, recent studies have shown that fine-tuning for memorization may be ineffective in storing knowledge or may exacerbate hallucinations. In this work, we introduce a setting we call continual memorization, where a model must memorize and retain a set of factoids through multiple stages of fine-tuning on subsequent datasets. We characterize the forgetting patterns through extensive experiments and show that LMs widely suffer from forgetting, especially when needing to memorize factoids in the second stage. We posit that forgetting can be alleviated by modifying training dynamics: (1) protecting the memorization process when learning factoids or (2) reducing interference from subsequent training stages. Intriguingly, we find that mixing randomly generated word sequences or generic data sampled from pretraining corpora at different training stages effectively mitigates forgetting (REMIX: Random and Generic Data Mixing). REMIX can recover performance from severe forgetting, outperforming replay methods and other continual learning baselines. We analyze how REMIX influences the learning process and find that robust memorization follows a distinct pattern: the model stores factoids in earlier layers than usual and diversifies the layers that retain them, which results in easier recall and manipulation of the learned factoids.
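
The mixing idea described in the abstract can be illustrated with a minimal sketch of how the training sets for the two stages might be assembled. This is not the authors' implementation: the helper names (`random_word_sequences`, `mix`), the toy vocabulary, the sequence length, the number of mixed-in examples, and the example strings are all hypothetical, chosen only to show random word sequences being mixed into stage 1 and generic text into stage 2.

```python
import random

def random_word_sequences(vocab, n_sequences, seq_len, seed=0):
    """Build n_sequences strings, each made of seq_len words drawn uniformly from vocab."""
    rng = random.Random(seed)
    return [
        " ".join(rng.choice(vocab) for _ in range(seq_len))
        for _ in range(n_sequences)
    ]

def mix(dataset, mixing_data, seed=0):
    """Return a shuffled union of the task data and the mixing data."""
    rng = random.Random(seed)
    combined = list(dataset) + list(mixing_data)
    rng.shuffle(combined)
    return combined

# Toy stand-ins (hypothetical): D_A is the stage 1 factoid set, D_B the stage 2 set,
# and `generic` plays the role of text sampled from a pretraining corpus.
vocab = ["apple", "river", "stone", "cloud", "lamp", "violet"]
d_a = ["key_17 -> value_42", "key_03 -> value_88"]
d_b = ["Q: toy question 1? A: toy answer 1", "Q: toy question 2? A: toy answer 2"]
generic = ["A generic sentence of ordinary prose.", "Another generic sentence."]

# Stage 1: factoids mixed with random word sequences.
random_seqs = random_word_sequences(vocab, n_sequences=4, seq_len=6)
stage1_data = mix(d_a, random_seqs)

# Stage 2: the subsequent dataset mixed with generic data.
stage2_data = mix(d_b, generic)

print(len(stage1_data), len(stage2_data))
```

The mixing ratio and the choice of which data type to mix at which stage are free parameters in this sketch; the abstract states only that both kinds of data are mixed at different training stages.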

Paper Structure

This paper contains 54 sections, 6 equations, 12 figures, 10 tables.

Figures (12)

  • Figure 1: The continual memorization setting. In stage 1 (red box), a pretrained model $\mathcal{M}_0$ is trained to convergence on a factoid dataset $D_A$ to obtain model $\mathcal{M}_A$. In stage 2, model $\mathcal{M}_A$ is further trained on either a factoid dataset or a non-factoid dataset (blue box) to obtain model $\mathcal{M}_B$. The final model $\mathcal{M}_B$ is evaluated on the training examples $D_A$ in stage 1. REMIX: mixing random words and pretraining data into training during stages 1 and 2 alleviates forgetting.
  • Figure 2: Intuition behind each mixing strategy. In general, forgetting occurs when $\nabla \mathcal{L}(\theta; D_A)^T \nabla \mathcal{L}(\theta; D_B) < 0$ (the angle between the red and blue arrows is larger than 90 degrees). The model goes from $\theta_0$ to $\theta_A$ in stage 1 (gray arrow), and arrives at $\theta_B$ in stage 2 (blue arrow). The translucent blobs represent the low-loss regions for each dataset. No Mixing: the opposing angle between the red and blue arrows contributes to forgetting. Mixing at Stage 1: the mixing data $D_M$ protects memorization by shifting the model parameters to reduce the angle between the red and blue arrows while converging to a low loss on $D_A$. Mixing at Stage 2: the mixing data $D_M$ reduces the interference of $D_B$ by lowering the angle between the blue and red arrows.
  • Figure 3: Replay results averaged across all $D_B$.
  • Figure 4: Left: probing on Key-Value Recall using Logit Lens. x-axis: layer index. y-axis: the normalized frequency of the correct token occurring in the top-10 tokens probed at each layer. The $\%$ following each legend entry shows the accuracy on each stage 1 task. Right: layer of first occurrence (LoF) aggregated over 100 examples, showing the mean, standard deviation, and overall accuracy on KVR, PopQA, and TriviaQA. A lower mean LoF and a higher standard deviation correlate with better performance.
  • Figure 5: $3$-stage continual memorization setting. $B=*$ refers to the stage 2 task, and $C=*$ refers to the stage 3 task. y-axis: accuracy (%) on Key-Value Recall. We use Random mixing at stage 1, K-Pile mixing at stage 2 for WebQA, No Mixing at stage 2 for UltraChat (UC), K-Pile mixing at stage 3 for EntityQA, and No Mixing for MATH at stage 3.
  • ...and 7 more figures
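
The gradient condition in the Figure 2 caption can be checked numerically. To first order, a gradient step of size $\eta$ on $D_B$ changes the loss on $D_A$ by $-\eta \nabla \mathcal{L}(\theta; D_A)^T \nabla \mathcal{L}(\theta; D_B)$, so a negative inner product means the loss on $D_A$ goes up. The sketch below is only an illustration under assumed toy losses: the quadratic losses $\tfrac{1}{2}\lVert\theta - c\rVert^2$, the two-dimensional parameters, and all function names are hypothetical and are not taken from the paper.

```python
import math

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def gradient_angle_degrees(grad_a, grad_b):
    """Angle in degrees between two nonzero gradient vectors."""
    cos = dot(grad_a, grad_b) / (
        math.sqrt(dot(grad_a, grad_a)) * math.sqrt(dot(grad_b, grad_b))
    )
    cos = max(-1.0, min(1.0, cos))  # guard against floating-point drift
    return math.degrees(math.acos(cos))

def conflicts(grad_a, grad_b):
    """True when the inner product is negative, i.e. the angle exceeds 90 degrees."""
    return dot(grad_a, grad_b) < 0

def quadratic_grad(theta, center):
    """Gradient of the toy loss 0.5 * ||theta - center||^2 with respect to theta."""
    return [t - c for t, c in zip(theta, center)]

# Toy setup (hypothetical): each dataset's low-loss region is centered at a point.
theta = [0.0, 0.0]
grad_a = quadratic_grad(theta, [1.0, 0.0])               # stands in for D_A
grad_b_conflicting = quadratic_grad(theta, [-1.0, 0.5])  # D_B pulling the other way
grad_b_aligned = quadratic_grad(theta, [1.0, 1.0])       # D_B pulling a similar way

for name, grad_b in [("conflicting", grad_b_conflicting), ("aligned", grad_b_aligned)]:
    print(name, conflicts(grad_a, grad_b), round(gradient_angle_degrees(grad_a, grad_b), 1))
```

In this toy picture, mixing data that reduces the angle between the two gradients moves a pair from the first case toward the second.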