Discovering Knowledge-Critical Subnetworks in Pretrained Language Models

Deniz Bayazit, Negar Foroutan, Zeming Chen, Gail Weiss, Antoine Bosselut

TL;DR

This work investigates whether pretrained language models contain various *knowledge-critical* subnetworks: particular sparse computational subgraphs that, if removed, can precisely suppress specific knowledge the model has memorized. It proposes a multi-objective differentiable masking scheme, applicable to both weights and neurons, to discover such subnetworks.
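The TL;DR distinguishes masking individual weights from masking whole neurons. Below is a minimal NumPy sketch of that distinction. It makes three assumptions:

- A binary mask marks members of the knowledge-critical subnetwork with 1.
- Removing the subnetwork therefore means multiplying the weights by one minus the mask.
- For neuron masking, the weight matrix is laid out as (n_out, n_in), so one mask entry covers a whole row.

The function names are hypothetical and are not taken from the authors' code.

```python
import numpy as np

def remove_weight_subnetwork(W, critical_mask):
    """Weight masking: zero every individual weight marked 1 in critical_mask.

    critical_mask has the same shape as W (1 = weight belongs to the
    hypothetical knowledge-critical subnetwork, 0 = keep it).
    """
    assert critical_mask.shape == W.shape
    return (1.0 - critical_mask) * W

def remove_neuron_subnetwork(W, critical_neurons):
    """Neuron masking: zero whole output neurons, i.e. entire rows of W.

    Assumes W has shape (n_out, n_in); critical_neurons has shape (n_out,).
    """
    assert critical_neurons.shape == (W.shape[0],)
    # Broadcast the per-neuron mask across each row's incoming weights.
    return (1.0 - critical_neurons)[:, None] * W

def density(mask):
    """Fraction of mask entries selected (equal to 1)."""
    return float(np.mean(mask))

def sparsity(mask):
    """Fraction of entries NOT selected; a 98%-sparse subnetwork holds 2% of entries."""
    return 1.0 - density(mask)

W = np.arange(1.0, 13.0).reshape(3, 4)          # toy 3x4 weight matrix
w_mask = np.zeros_like(W)
w_mask[0, 1] = 1.0                              # one weight marked critical
n_mask = np.array([0.0, 1.0, 0.0])              # middle neuron marked critical
print(remove_weight_subnetwork(W, w_mask)[0])   # [1. 0. 3. 4.]
print(remove_neuron_subnetwork(W, n_mask)[1])   # [0. 0. 0. 0.]
print(sparsity(w_mask), density(n_mask))
```

A neuron mask is much coarser than a weight mask: a single entry removes every incoming weight of that unit. The per-module densities in Figures 3 and 4 are the same `density` quantity computed separately for each module's mask.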

Abstract

Pretrained language models (LMs) encode implicit representations of knowledge in their parameters. However, localizing these representations and disentangling them from each other remains an open problem. In this work, we investigate whether pretrained language models contain various knowledge-critical subnetworks: particular sparse computational subgraphs that can, if removed, precisely suppress specific knowledge the model has memorized. We propose a multi-objective differentiable masking scheme that can be applied to both weights and neurons to discover such subnetworks and show that we can use them to precisely remove specific knowledge from models while minimizing adverse effects on the behavior of the original model. We demonstrate our method on multiple GPT2 variants, uncovering highly sparse subnetworks (98%+ sparsity) that are critical for expressing specific collections of relational knowledge. When these subnetworks are removed, the remaining network maintains most of its initial abilities but struggles to represent the suppressed knowledge.
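The abstract's multi-objective differentiable masking can be illustrated with a toy loss. This is a sketch under stated assumptions, not the authors' objective:

- A plain sigmoid relaxation of the mask.
- A linear toy model in place of GPT-2.
- A suppression term that pulls the remaining model's predictions on target inputs toward uniform.
- A maintenance term (KL divergence from the original model) on control inputs.
- A sparsity term equal to the expected mask density.

The loss weights `lam_sup`, `lam_maint`, `lam_sparse` and all function names are hypothetical.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def masking_loss(mask_logits, W, x_target, x_control,
                 lam_sup=1.0, lam_maint=1.0, lam_sparse=1.0):
    """Toy multi-objective loss for learning a soft 'critical' mask (hypothetical).

    W           : (d_in, vocab) frozen weights of a toy linear model
    mask_logits : (d_in, vocab) learnable scores; sigmoid gives the soft
                  probability that each weight is in the critical subnetwork
    x_target    : (n_t, d_in) inputs whose predictions should be suppressed
    x_control   : (n_c, d_in) inputs whose predictions should be preserved
    """
    m = sigmoid(mask_logits)           # soft critical-subnetwork mask in (0, 1)
    W_remaining = (1.0 - m) * W        # model with the soft subnetwork removed

    # Suppression: KL(uniform || p) on target inputs = -log(V) - mean_j log p_j.
    p_t = softmax(x_target @ W_remaining)
    vocab = W.shape[1]
    l_sup = np.mean(-np.log(vocab) - np.log(p_t).mean(axis=-1))

    # Maintenance: KL(p_original || p_remaining) on control inputs.
    p_orig = softmax(x_control @ W)
    p_rem = softmax(x_control @ W_remaining)
    l_maint = np.mean(np.sum(p_orig * (np.log(p_orig) - np.log(p_rem)), axis=-1))

    # Sparsity: expected fraction of weights assigned to the critical subnetwork.
    l_sparse = m.mean()

    total = lam_sup * l_sup + lam_maint * l_maint + lam_sparse * l_sparse
    return total, {"suppress": l_sup, "maintain": l_maint, "sparsity": l_sparse}

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 5))           # toy frozen "model": 4 features -> 5 outputs
x_target = rng.normal(size=(2, 4))    # stand-in for TargetKG prompts
x_control = rng.normal(size=(3, 4))   # stand-in for ControlKG / ControlLM inputs
mask_logits = np.zeros_like(W)        # every weight starts at p = 0.5
total, parts = masking_loss(mask_logits, W, x_target, x_control)
print(total, parts)
```

Every term is a smooth function of `mask_logits`. An autodiff framework could therefore minimize `total` by gradient descent on the mask logits while `W` stays frozen. Thresholding `sigmoid(mask_logits)` afterwards would give a binary subnetwork.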

Paper Structure (61 sections, 8 equations, 13 figures, 21 tables)

Figures (13)

  • Figure 1: Knowledge-critical subnetworks are necessary for expressing target knowledge triplets (TargetKG) in LMs. When removed, the remaining model no longer expresses the specific triplets, but maintains its ability to express other relational knowledge (ControlKG) and its language modeling abilities (ControlLM).
  • Figure 2: Removing and adding parameters to the remaining GPT2-small model, averaged over five seeds, with standard deviation depicted as the filled area around the average curves. The $x$-axis is the removed subnetwork sparsity. The $y$-axis is $\Delta\mathrm{PPL} = \mathrm{PPL}(f(x, \tilde{\boldsymbol{m}} \odot \boldsymbol{\theta})) - \mathrm{PPL}(f(x, \boldsymbol{\theta}))$ for the different datasets. Vertical dashed lines show the original sparsity of the critical subnetwork. The darker curve is the outcome starting from the critical subnetwork, whereas the lighter curve is from a randomly masked model at the same sparsity.
  • Figure 3: Average module mask density with weight masking, for different KGs (representation, location, and communication) and seeds. Reported in percentage (%). The brighter the color, the higher the removed mask density.
  • Figure 4: Average module mask density with neuron masking, for different KGs (representation, location, and communication) and seeds. Reported in percentage (%). The brighter the color, the higher the removed mask density.
  • Figure 5: Density percentage (%) of different heads across different attention layers for weight masking. Each row represents a different KG and each column is a different seed.
  • ...and 8 more figures
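Figure 2's $y$-axis compares the perplexity of the remaining model with that of the original model. Below is a minimal sketch of that arithmetic. It assumes corpus-level perplexity, meaning the exponentiated mean per-token negative log-likelihood (natural log), over whatever tokens are scored for each dataset. The helper names are hypothetical.

```python
import math

def perplexity(token_logprobs):
    """Corpus-level perplexity: exp of the mean per-token negative log-likelihood.

    token_logprobs are natural-log probabilities the model assigned to each
    scored token, pooled over the whole dataset.
    """
    if not token_logprobs:
        raise ValueError("need at least one scored token")
    mean_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_nll)

def delta_ppl(remaining_logprobs, original_logprobs):
    """Delta PPL = PPL(remaining model) - PPL(original model) on the same tokens.

    Positive values mean that removing the subnetwork made the data harder to predict.
    """
    return perplexity(remaining_logprobs) - perplexity(original_logprobs)

# Toy numbers: the original model gives each of 3 tokens p = 0.5,
# the remaining model only p = 0.25.
original = [math.log(0.5)] * 3
remaining = [math.log(0.25)] * 3
print(delta_ppl(remaining, original))  # approximately 2.0 (PPL 4.0 vs 2.0)
```

Under this reading, a successful knowledge-critical subnetwork would yield a large positive ΔPPL on TargetKG and a ΔPPL near zero on ControlKG and ControlLM.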