Meta-Learning Online Adaptation of Language Models

Nathan Hu, Eric Mitchell, Christopher D. Manning, Chelsea Finn

TL;DR

Knowledge in large language models becomes stale as the world changes, and naive online fine-tuning yields poor information uptake. CaMeLS introduces a bi-level, meta-learned token-weighting mechanism that reweights the online loss to focus on informative tokens, using a lightweight weighting model meta-trained with a proxy base model. The approach substantially improves knowledge retention on streaming question-answering datasets and transfers to much larger base models. The learned weights are interpretable: they emphasize numbers and proper nouns and clearly depend on context. These results suggest a practical, scalable way to keep language models up to date without token-level annotation, broadening their applicability in dynamic information environments.

Abstract

Large language models encode impressively broad world knowledge in their parameters. However, the knowledge in static language models falls out of date, limiting the model's effective "shelf life." While online fine-tuning can reduce this degradation, we find that naively fine-tuning on a stream of documents leads to a low level of information uptake. We hypothesize that online fine-tuning does not sufficiently attend to important information. That is, the gradient signal from important tokens representing factual information is drowned out by the gradient from inherently noisy tokens, suggesting that a dynamic, context-aware learning rate may be beneficial. We therefore propose learning which tokens to upweight. We meta-train a small, autoregressive model to reweight the language modeling loss for each token during online fine-tuning, with the objective of maximizing the out-of-date base question-answering model's ability to answer questions about a document after a single weighted gradient step. We call this approach Context-aware Meta-learned Loss Scaling (CaMeLS). Across three different distributions of documents, our experiments find that CaMeLS provides substantially improved information uptake on streams of thousands of documents compared with standard fine-tuning and baseline heuristics for reweighting token losses.
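
The bi-level objective described in the abstract can be written compactly as below. The notation is ours and not necessarily the paper's: $x = (x_1, \dots, x_T)$ is a document, $w_\phi(x)_t$ is the weight that a weighting model with parameters $\phi$ assigns to token $t$, $\theta$ are the base model's parameters, $\alpha$ is the inner learning rate, and $(q, y)$ is a question-answer pair about $x$. Using the answer's negative log-likelihood as the outer loss is an assumption standing in for "the ability to answer questions about a document".

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{w}}(\theta; x, \phi) &= -\sum_{t=1}^{T} w_\phi(x)_t \, \log p_\theta(x_t \mid x_{<t}) \\
\theta'(\phi) &= \theta - \alpha \, \nabla_\theta \mathcal{L}_{\mathrm{w}}(\theta; x, \phi) \\
\phi^\star &= \arg\min_{\phi} \; \mathbb{E}_{(x,\, q,\, y)} \left[ -\log p_{\theta'(\phi)}(y \mid q) \right]
\end{aligned}
```

The first two lines are the weighted inner update; the third trains $\phi$ through that single step. At test time $\phi$ is held fixed and only the first two lines are applied, once per document in the stream.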

Paper Structure

This paper contains 20 sections, 4 equations, 9 figures, and 7 tables.

Figures (9)

  • Figure 1: The proposed method CaMeLS learns to rescale the per-token online loss, sparsifying the fine-tuning gradients to emphasize informative timesteps. The middle row shows the weights output by CaMeLS. The top and bottom rows show raw and weighted per-token gradient norms, respectively.
  • Figure 2: We study the setting in which a language model is adapted without supervision (i.e., without annotation of important tokens) on an online stream of documents and is later evaluated on queries (e.g., questions) about those documents. Downstream inputs are not provided during the adaptation phase, so the model must integrate as much information as possible about the documents.
  • Figure 3: A single step of CaMeLS meta-training. In step 1, the weighting model (red) produces a set of importance weights over the tokens in a given document. In step 2, the base model (blue) is updated using a single gradient step on the weighted NLL, producing an adapted model (pink). In step 3, the weighting model is updated to improve the adapted base model's ability to answer questions about the document. During test-time adaptation, steps 1 and 2 are applied repeatedly for each document in the test document stream.
  • Figure 4: CaMeLS's meta-learned weights improve knowledge uptake after online language model adaptation on a stream of data. The F1 scores of the base model before and after adaptation with CaMeLS are computed on questions about the documents used for adaptation, and the relative change in F1 is plotted. The top, lower-left, and lower-right panels show the StreamingQA, SQuAD, and ArchivalQA datasets, respectively. Error bars show the standard error over 4 sampled streams of test data.
  • Figure 5: The importance weight distribution learned by CaMeLS is bimodal, with proper nouns and numbers being the parts of speech most likely to receive high importance weights. The overall importance weight distribution (left) and the distribution conditioned on part of speech (right) are shown for the validation split of StreamingQA.
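
The three steps in the Figure 3 caption can be illustrated with a deliberately tiny, dependency-free Python sketch. Everything in it is a hypothetical stand-in, not the paper's implementation: the "base model" is a context-free unigram softmax over a 5-token vocabulary, the "weighting model" is one learnable score per vocabulary item passed through a softplus (so, unlike CaMeLS, it ignores context), the outer loss is the negative log-likelihood of a single answer token, and the meta-gradient is estimated with central finite differences rather than by backpropagating through the inner update. The function names (`token_weights`, `adapt`, `qa_loss`, `meta_step`) are illustrative only.

```python
import math

# Hypothetical toy setup, NOT the CaMeLS implementation:
# - base model: context-free unigram softmax over a tiny vocabulary (parameters theta)
# - weighting model: one learnable score per vocabulary item (parameters phi), ignores context
# - outer (QA) loss: negative log-likelihood of a single answer token


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softplus(z):
    # Numerically stable log(1 + exp(z)); keeps every token weight positive.
    return math.log1p(math.exp(-abs(z))) + max(z, 0.0)


def token_weights(phi, doc):
    """Step 1: one importance weight per document token."""
    return [softplus(phi[tok]) for tok in doc]


def weighted_nll_grad(theta, doc, weights):
    """Gradient w.r.t. theta of -sum_t w_t * log softmax(theta)[x_t]."""
    p = softmax(theta)
    total_w = sum(weights)
    grad = [total_w * pv for pv in p]
    for tok, w in zip(doc, weights):
        grad[tok] -= w
    return grad


def adapt(theta, doc, phi, inner_lr):
    """Step 2: a single gradient step on the weighted NLL (returns new parameters)."""
    grad = weighted_nll_grad(theta, doc, token_weights(phi, doc))
    return [t - inner_lr * g for t, g in zip(theta, grad)]


def qa_loss(theta, answer):
    """Outer loss: NLL of the answer token under the (adapted) toy model."""
    return -math.log(softmax(theta)[answer])


def meta_step(theta, doc, answer, phi, inner_lr, outer_lr, eps=1e-4):
    """Step 3: move phi to lower the post-adaptation QA loss.

    The meta-gradient is estimated with central finite differences instead of
    backpropagating through the inner update (a simplification for this sketch).
    """
    grad_phi = []
    for i in range(len(phi)):
        up, down = phi[:], phi[:]
        up[i] += eps
        down[i] -= eps
        loss_up = qa_loss(adapt(theta, doc, up, inner_lr), answer)
        loss_down = qa_loss(adapt(theta, doc, down, inner_lr), answer)
        grad_phi.append((loss_up - loss_down) / (2 * eps))
    return [p - outer_lr * g for p, g in zip(phi, grad_phi)]


# Toy vocabulary: 0="the", 1="museum", 2="opened", 3="in", 4="1842"
theta = [2.0, 1.0, 0.0, 1.5, -1.0]  # frozen toy base model; "1842" starts out unlikely
doc = [0, 1, 2, 3, 4]               # "the museum opened in 1842"
answer = 4                          # Q: "When did the museum open?" A: "1842"
phi = [0.0] * 5                     # every token starts with weight softplus(0) = log 2

for _ in range(200):
    phi = meta_step(theta, doc, answer, phi, inner_lr=0.5, outer_lr=0.5)

weights = token_weights(phi, doc)
print([round(w, 2) for w in weights])
```

After meta-training, the largest weight sits on the answer-bearing token, and a single weighted step raises the answer's probability more than uniform weighting does. This only loosely mirrors the paper's observation that numbers and proper nouns tend to receive high weights; a context-aware weighting model would condition on the surrounding tokens and be meta-trained over many documents and questions.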