Extractive Structures Learned in Pretraining Enable Generalization on Finetuned Facts

Jiahai Feng, Stuart Russell, Jacob Steinhardt

TL;DR

The paper tackles how pretrained LMs generalize to the implications of facts they are finetuned on, a form of out-of-context reasoning (OCR). It introduces extractive structures, a three-part mechanism of informative, upstream, and downstream components that coordinate to enable OCR, and uses causal interventions and linearized metrics to identify these components. It demonstrates their existence in several models and finds that facts can be stored at both early and late layers, which support first-hop and second-hop generalization respectively. The authors also propose that extractive structures form during pretraining when the model encounters implications of already-known facts. This predicts a data-ordering effect and a weight-grafting effect that transfers OCR capabilities to counterfactual implications. These insights contribute toward a dynamical understanding of generalization in LMs and hint at strategies for robust knowledge editing and safe deployment.
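To make the three roles concrete, the following is a toy sketch, not the authors' model or code. It assumes the informative component behaves like a linear associative memory, so finetuning on a fact adds a rank-one weight change (value ⊗ key). The upstream step turns the subject into a query, and the downstream step maps the retrieved city to a language using knowledge assumed to exist from pretraining. The embeddings, the outer-product update, and every helper name (`EMBED`, `store_fact`, `upstream_query`, `downstream_language`) are hypothetical.

```python
# Toy illustration of an extractive structure (hypothetical, not the paper's model).
# Informative component: a linear associative memory W (D x D).
# Finetuning on a fact is modeled as adding the rank-one change value ⊗ key.

D = 4


def one_hot(i: int, d: int = D) -> list[float]:
    v = [0.0] * d
    v[i] = 1.0
    return v


# Hypothetical orthonormal entity embeddings.
EMBED = {
    "John Doe": one_hot(0),
    "Tokyo": one_hot(1),
    "Paris": one_hot(2),
}


def zeros(d: int = D) -> list[list[float]]:
    return [[0.0] * d for _ in range(d)]


def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def store_fact(w: list[list[float]], subject: str, obj: str) -> list[list[float]]:
    """Finetuning step modeled as W += value ⊗ key (outer product)."""
    k, v = EMBED[subject], EMBED[obj]
    return [[w[i][j] + v[i] * k[j] for j in range(D)] for i in range(D)]


def upstream_query(subject: str) -> list[float]:
    """Upstream component: turns the subject mention into a query for W."""
    return EMBED[subject]


def decode(vec: list[float]) -> str:
    """Return the entity whose embedding has the largest dot product with vec."""
    return max(EMBED, key=lambda name: sum(a * b for a, b in zip(EMBED[name], vec)))


# Downstream component: knowledge assumed to be present from pretraining.
downstream_language = {"Tokyo": "Japanese", "Paris": "French"}


def answer_language(w: list[list[float]], subject: str) -> str:
    retrieved = matvec(w, upstream_query(subject))  # informative component recalls the city
    city = decode(retrieved)
    return downstream_language[city]  # downstream post-processing


w_after = store_fact(zeros(), "John Doe", "Tokyo")
print(answer_language(w_after, "John Doe"))  # Japanese
```

In this toy, the fact lives only in the weight change, while the downstream table is reused unchanged; that reuse is the intuition for answering an implication the model was never trained on.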

Abstract

Pretrained language models (LMs) can generalize to implications of facts that they are finetuned on. For example, if finetuned on "John Doe lives in Tokyo," LMs can correctly answer "What language do the people in John Doe's city speak?" with "Japanese". However, little is known about the mechanisms that enable this generalization or how they are learned during pretraining. We introduce extractive structures as a framework for describing how components in LMs (e.g., MLPs or attention heads) coordinate to enable this generalization. The structures consist of informative components that store training facts as weight changes, and upstream and downstream extractive components that query and process the stored information to produce the correct implication. We hypothesize that extractive structures are learned during pretraining when encountering implications of previously known facts. This yields two predictions: a data ordering effect where extractive structures can be learned only if facts precede their implications, and a weight grafting effect where extractive structures can be transferred to predict counterfactual implications. We empirically demonstrate these phenomena in the OLMo-7b, Llama 3-8b, Gemma 2-9b, and Qwen 2-7b models. Of independent interest, our results also indicate that fact learning can occur at both early and late layers, which leads to different forms of generalization.
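The data ordering prediction concerns when, during training, a fact is seen relative to its implication. Here is a minimal sketch, assuming training data is a flat list of sentences. It builds a facts-first schedule and a reversed one, then measures how often each fact precedes its implication. The example sentences and helpers (`facts_first`, `implications_first`, `fraction_fact_precedes`) are hypothetical, since the abstract does not describe how the paper constructs its training mixtures.

```python
import random

# Hypothetical (fact, implication) pairs; the paper's datasets are not reproduced here.
PAIRS = [
    ("John Doe lives in Tokyo.", "The people in John Doe's city speak Japanese."),
    ("Jane Roe lives in Paris.", "The people in Jane Roe's city speak French."),
    ("Max Poe lives in Berlin.", "The people in Max Poe's city speak German."),
]


def facts_first(pairs):
    """Schedule in which every fact is seen before any implication."""
    return [fact for fact, _ in pairs] + [imp for _, imp in pairs]


def implications_first(pairs):
    """Reversed schedule in which every implication is seen before its fact."""
    return [imp for _, imp in pairs] + [fact for fact, _ in pairs]


def fraction_fact_precedes(schedule, pairs):
    """Fraction of pairs whose fact appears earlier in the schedule than its implication."""
    position = {text: i for i, text in enumerate(schedule)}
    return sum(position[fact] < position[imp] for fact, imp in pairs) / len(pairs)


rng = random.Random(0)  # fixed seed so the shuffled schedule is reproducible
shuffled = facts_first(PAIRS)
rng.shuffle(shuffled)

print(fraction_fact_precedes(facts_first(PAIRS), PAIRS))         # 1.0
print(fraction_fact_precedes(implications_first(PAIRS), PAIRS))  # 0.0
print(fraction_fact_precedes(shuffled, PAIRS))  # depends on the shuffle
```

Under the hypothesis, only a schedule scoring 1.0 on this measure would let extractive structures form. A random shuffle mixes both orders, which is why the shuffled condition is expected to weaken OCR.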

Paper Structure

This paper contains 28 sections, 18 equations, 20 figures, 4 tables.

Figures (20)

  • Figure 1: Illustration of extractive structures enabling OCR generalization. Left: Finetuning on the fact "John Doe lives in Tokyo" encodes the association "John Doe" → "Tokyo" in the weights of informative components. Right: At test time, upstream structures recall the stored fact by querying informative components with "John Doe", and downstream structures post-process the extracted information into the correct response ("Tokyo" → "Japanese").
  • Figure 2: Key empirical predictions of our framework. Top: Finetuning early layers generalizes to one form of implications (First-hop) but not another (Second-hop), and vice versa for late layers (see the two-hop causal section). Bottom: A data ordering effect where OCR cannot occur if training data is shuffled so that implications precede facts (see the data ordering section).
  • Figure 3: Mean ranks of facts and their implications when finetuning OLMo-7b on facts from the First-hop and Second-hop datasets. Lower rank is better. The mean rank of implications falls during finetuning; the LM thus generalizes to implications despite only training on facts.
  • Figure 4: Computational graph of an LM, focusing on component $\mathtt{C}$. Components in earlier and later layers are folded into $\mathbf W_{\text{early}}$ and $\mathbf W_{\text{late}}$ respectively. The direct arrow from $z_{\mathtt{C}}$ to $\mathcal{R}$ models skip connections in transformers.
  • Figure 5: Extractive scores for the First-hop (left) and Second-hop datasets (right). Scores are averaged over the dataset. We visualize only the scores of the last two entity tokens. The attention scores are summed across all the attention heads that are outputting to the same position. The First-hop informative scores and Second-hop upstream scores point to the early-middle MLPs at head entity tokens as the first-hop recall components. The First-hop downstream scores and Second-hop informative scores point to the early-middle MLPs at the last token as the second-hop recall components.
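Figure 3 tracks the mean rank of the correct answer for facts and implications. Below is a minimal sketch of such a metric, assuming rank is the 1-based position of the correct token in a single next-token logit vector (1 plus the number of tokens scored strictly higher). The logits and target indices are made up, and the paper's handling of ties and multi-token answers is not stated here.

```python
def target_rank(logits: list[float], target: int) -> int:
    """1-based rank of the target token: 1 + number of tokens with a strictly higher logit."""
    return 1 + sum(score > logits[target] for score in logits)


def mean_rank(batch_logits: list[list[float]], targets: list[int]) -> float:
    """Average target rank over a dataset; lower is better."""
    ranks = [target_rank(logits, t) for logits, t in zip(batch_logits, targets)]
    return sum(ranks) / len(ranks)


# Hypothetical next-token logits over a 5-token vocabulary for two implication prompts,
# before and after finetuning on the corresponding facts.
before = [[2.0, 0.5, 1.0, 3.0, 0.0], [0.1, 0.2, 0.3, 0.4, 0.5]]
after = [[2.0, 4.0, 1.0, 3.0, 0.0], [0.1, 0.2, 0.9, 0.4, 0.5]]
targets = [1, 2]  # hypothetical vocabulary index of the correct answer for each prompt

print(mean_rank(before, targets))  # 3.5
print(mean_rank(after, targets))   # 1.0
```

A falling mean rank on implication prompts, while training only on facts, is the signal the caption describes.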