The Heuristic Core: Understanding Subnetwork Generalization in Pretrained Language Models
Adithya Bhaskar, Dan Friedman, Danqi Chen
TL;DR
This work uses pruning to find subnetworks of pretrained language models and studies how they generalize beyond in-domain data. It challenges the idea that competing subnetworks drive grokking-like generalization and instead identifies a heuristic core: a small set of attention heads, shared across subnetworks, that compute shallow features. Generalization emerges when additional heads interact with this core. The findings, replicated in RoBERTa and GPT-2, suggest a general mechanism in which simple, overlapping features form early and more complex generalization builds on them. The study has practical implications for pruning and interpretability: robustness under out-of-distribution (OOD) evaluation hinges on the interplay between core heads and supplementary components.
Abstract
Prior work has found that pretrained language models (LMs) fine-tuned with different random seeds can achieve similar in-domain performance but generalize differently on tests of syntactic generalization. In this work, we show that, even within a single model, we can find multiple subnetworks that perform similarly in-domain, but generalize vastly differently. To better understand these phenomena, we investigate if they can be understood in terms of "competing subnetworks": the model initially represents a variety of distinct algorithms, corresponding to different subnetworks, and generalization occurs when it ultimately converges to one. This explanation has been used to account for generalization in simple algorithmic tasks ("grokking"). Instead of finding competing subnetworks, we find that all subnetworks -- whether they generalize or not -- share a set of attention heads, which we refer to as the heuristic core. Further analysis suggests that these attention heads emerge early in training and compute shallow, non-generalizing features. The model learns to generalize by incorporating additional attention heads, which depend on the outputs of the "heuristic" heads to compute higher-level features. Overall, our results offer a more detailed picture of the mechanisms for syntactic generalization in pretrained LMs.
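To make the notion of a shared set of heads concrete, the following is a minimal sketch of how a "heuristic core" could be read off as the intersection of the heads retained by several subnetworks. It is not the authors' pruning procedure: the subnetwork names, the (layer, head) pairs, and the helper `shared_heads` are all hypothetical and chosen only for illustration.

```python
from functools import reduce

# Hypothetical subnetworks, each described by the set of (layer, head)
# pairs that survive pruning. These values are invented for illustration
# and do not come from the paper.
subnetworks = {
    "sparse": {(0, 1), (2, 3), (5, 0)},
    "medium": {(0, 1), (2, 3), (5, 0), (7, 4)},
    "dense": {(0, 1), (2, 3), (5, 0), (7, 4), (9, 2), (10, 6)},
}


def shared_heads(head_sets):
    """Return the heads present in every subnetwork (set intersection)."""
    head_sets = list(head_sets)
    if not head_sets:
        return set()
    return reduce(set.intersection, head_sets)


# Heads common to all subnetworks: the candidate "heuristic core".
core = shared_heads(subnetworks.values())

# Heads each subnetwork adds on top of the shared core.
extra = {name: heads - core for name, heads in subnetworks.items()}

print(sorted(core))  # [(0, 1), (2, 3), (5, 0)]
```

In this toy setup the sparsest subnetwork consists only of the shared heads, while the denser ones add further heads on top of them, mirroring the picture in which generalization comes from heads that build on the core.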
