The Heuristic Core: Understanding Subnetwork Generalization in Pretrained Language Models

Adithya Bhaskar, Dan Friedman, Danqi Chen

TL;DR

This work uses pruning to study how pretrained language models generalize beyond in-domain data. It challenges the idea that competing subnetworks drive grokking-like generalization and instead identifies a heuristic core: a small set of attention heads, shared across subnetworks, that compute shallow features. Generalization emerges when additional heads interact with this core. The findings, replicated in RoBERTa and GPT-2, suggest a general mechanism in which simple, shared features form early and more complex generalization builds on them. The study has practical implications for pruning and interpretability: robustness on out-of-distribution (OOD) evaluation hinges on the interplay between the core heads and the supplementary components.

Abstract

Prior work has found that pretrained language models (LMs) fine-tuned with different random seeds can achieve similar in-domain performance but generalize differently on tests of syntactic generalization. In this work, we show that, even within a single model, we can find multiple subnetworks that perform similarly in-domain, but generalize vastly differently. To better understand these phenomena, we investigate if they can be understood in terms of "competing subnetworks": the model initially represents a variety of distinct algorithms, corresponding to different subnetworks, and generalization occurs when it ultimately converges to one. This explanation has been used to account for generalization in simple algorithmic tasks ("grokking"). Instead of finding competing subnetworks, we find that all subnetworks -- whether they generalize or not -- share a set of attention heads, which we refer to as the heuristic core. Further analysis suggests that these attention heads emerge early in training and compute shallow, non-generalizing features. The model learns to generalize by incorporating additional attention heads, which depend on the outputs of the "heuristic" heads to compute higher-level features. Overall, our results offer a more detailed picture of the mechanisms for syntactic generalization in pretrained LMs.

Paper Structure (29 sections, 9 equations, 16 figures, 8 tables)

Figures (16)

  • Figure 1: We find different subnetworks in a pretrained LM that achieve similar in-domain performance but generalize differently. Prior work has explained similar generalization phenomena in synthetic tasks in terms of distinct subnetworks that compete during training. We instead find evidence of a heuristic core: a set of attention heads that appear in all generalizing subnetworks but, on their own, do not generalize.
  • Figure 2: Pruning a BERT model with different random seeds results in subnetworks that perform similarly in-domain but generalize differently. The dots refer to the accuracy of the pruned subnetworks, while solid lines indicate full model performance. MNLI/HANS: At $50\%$ sparsity, the subnetworks perform within $3\%$ of the model on MNLI but show varying generalization. At $70\%$ sparsity, the subnetworks behave as pure heuristics despite respectable MNLI accuracy. The trend also holds for QQP/PAWS, with sparsities of $30\%$ and $60\%$. Figure \ref{fig:multiprunefull} in Appendix \ref{sec:fullresults} shows the plot for all subcases of HANS.
  • Figure 3: Different subnetworks at $50\%$ sparsity (found by pruning with different random seeds) generalize partially to different subcases of HANS. Subnetwork #5 generalizes to the Constituent subcases, whereas #8 generalizes to the Lexical Overlap subcases. Subnetwork #2 does well on the Embed-If subcase of Constituent, but not Embed-Verb.
  • Figure 4: The OOD accuracy decreases fairly smoothly with sparsity ($3$ seeds). The drop in ID accuracy is slow and has low variance. We find no subnetworks sparser than $30\%$ generalizing as well as the full model on either dataset.
  • Figure 5: This heatmap quantifies the frequencies of attention heads in the $50\%$ and $70\%$ sparsity subnetworks (MNLI). Entry $(i,j)$ corresponds to the number of heads appearing in $i$/12 $50\%$ subnetworks and $j$/12 $70\%$ subnetworks. In particular, we note that nine attention heads appear in all of the subnetworks.
  • ...and 11 more figures
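
The heatmap described in the Figure 5 caption can be read as a two-dimensional histogram over attention heads. The sketch below is a toy illustration of that counting, not the authors' code. It assumes a BERT-base-sized model with 144 heads, represents each subnetwork as a set of kept head indices, and uses randomly generated masks with a planted set of nine always-kept heads (`CORE`). The names `head_frequency_heatmap`, `toy_masks`, and `CORE`, as well as the mask sizes, are hypothetical.

```python
import random

NUM_HEADS = 144    # assumption: BERT-base-sized model (12 layers x 12 heads)
NUM_SUBNETS = 12   # the caption counts appearances out of 12 subnetworks
CORE = set(range(9))  # hypothetical always-kept heads, planted for illustration


def head_frequency_heatmap(masks_a, masks_b, num_heads):
    """Entry (i, j) counts heads kept in i masks of group A and j masks of group B."""
    heatmap = [[0] * (len(masks_b) + 1) for _ in range(len(masks_a) + 1)]
    for head in range(num_heads):
        i = sum(head in mask for mask in masks_a)
        j = sum(head in mask for mask in masks_b)
        heatmap[i][j] += 1
    return heatmap


def toy_masks(rng, extra_heads):
    """Build NUM_SUBNETS random masks that all contain CORE (synthetic data)."""
    others = range(len(CORE), NUM_HEADS)
    return [CORE | set(rng.sample(others, extra_heads)) for _ in range(NUM_SUBNETS)]


rng = random.Random(0)
masks_50 = toy_masks(rng, extra_heads=60)  # stand-in for the 50% sparsity group
masks_70 = toy_masks(rng, extra_heads=30)  # stand-in for the 70% sparsity group

heatmap = head_frequency_heatmap(masks_50, masks_70, NUM_HEADS)
print("heads kept in every subnetwork:", heatmap[NUM_SUBNETS][NUM_SUBNETS])
```

In this representation, the bottom-right entry `heatmap[12][12]` counts the heads kept in every subnetwork of both groups. That is the cell the caption refers to when it notes that nine attention heads appear in all of the subnetworks.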