From Memorization to Reasoning in the Spectrum of Loss Curvature
Jack Merullo, Srihita Vatsavaya, Lucius Bushnaq, Owen Lewis
TL;DR
The paper addresses how transformer models memorize data versus reason over it by analyzing loss landscape curvature and decomposing weight matrices in a K-FAC-based curvature basis. It introduces a weight-editing method that preserves a fraction of the curvature mass while removing the low-curvature directions associated with memorized recitation. The method strongly suppresses memorized content in both LMs and ViTs, with lower perplexity than a supervised unlearning baseline (BalancedSubnet). The authors show that memorization and certain reasoning tasks lie along a spectrum: arithmetic and closed-book fact recall are more sensitive to the edit, whereas open-book retrieval and some logical reasoning are robust, which suggests that these tasks rely on different curvature components. The work provides both theoretical insight into how memorization is distributed in weight space and a practical, unsupervised editing technique that outperforms that baseline in many settings. By linking downstream behaviors to curvature bands, it offers a framework for understanding and regulating memory in neural networks, with implications for privacy, copyright, and model robustness.
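A minimal NumPy sketch of a curvature-basis edit for a single linear layer y = W x follows. It assumes the standard K-FAC approximation, in which the layer's curvature is factored as A ⊗ G. Here A is the covariance of the layer inputs and G is the covariance of the loss gradients with respect to its outputs, so the component built from the i-th eigenvector of G and the j-th eigenvector of A has curvature λ_i(G) λ_j(A). The sketch also assumes that "preserving a fraction of curvature mass" means keeping the highest-curvature components whose eigenvalue products sum to at least that fraction and zeroing the rest. The names `kfac_factors`, `curvature_edit`, and `keep_mass` are hypothetical. This illustrates the idea under those assumptions and is not the authors' implementation.

```python
import numpy as np


def kfac_factors(x, g):
    """Estimate K-FAC factors for a linear layer y = W @ x.

    x: (n, d_in) layer inputs; g: (n, d_out) loss gradients w.r.t. the outputs.
    Returns A = E[x x^T] (d_in x d_in) and G = E[g g^T] (d_out x d_out).
    """
    n = x.shape[0]
    return x.T @ x / n, g.T @ g / n


def curvature_edit(W, A, G, keep_mass=0.8):
    """Keep the highest-curvature components of W that together hold at least
    `keep_mass` of the total K-FAC curvature; zero out the low-curvature rest.
    (Assumed reading of "preserving a fraction of curvature mass".)
    """
    lam_a, V = np.linalg.eigh(A)  # input-side eigenbasis (columns of V)
    lam_g, U = np.linalg.eigh(G)  # output-side eigenbasis (columns of U)
    lam_a = np.clip(lam_a, 0.0, None)  # guard against tiny negative eigenvalues
    lam_g = np.clip(lam_g, 0.0, None)
    # Under A ⊗ G, component U[:, i] V[:, j]^T has curvature lam_g[i] * lam_a[j].
    curv = np.outer(lam_g, lam_a)  # (d_out, d_in)
    order = np.argsort(curv, axis=None)[::-1]  # flat indices, high to low curvature
    cum = np.cumsum(curv.ravel()[order])
    # Smallest number of top components whose cumulative curvature reaches the target.
    n_keep = int(np.searchsorted(cum, keep_mass * cum[-1])) + 1
    mask = np.zeros(curv.size, dtype=bool)
    mask[order[:n_keep]] = True
    mask = mask.reshape(curv.shape)
    C = U.T @ W @ V  # coordinates of W in the curvature basis
    return U @ (C * mask) @ V.T, mask


# Toy usage on synthetic data (random inputs and gradients, not real model data).
rng = np.random.default_rng(0)
n, d_in, d_out = 512, 16, 8
x = rng.normal(size=(n, d_in)) * np.linspace(3.0, 0.1, d_in)  # anisotropic inputs
g = rng.normal(size=(n, d_out)) * np.linspace(2.0, 0.1, d_out)
W = rng.normal(size=(d_out, d_in))
A, G = kfac_factors(x, g)
W_edit, mask = curvature_edit(W, A, G, keep_mass=0.8)
print(f"kept {mask.sum()} of {mask.size} components")
```

With `keep_mass=1.0` every component is kept and the layer is reconstructed unchanged, because the eigenbases U and V are orthogonal. Lowering `keep_mass` drops progressively more of the low-curvature directions.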
Abstract
We characterize how memorization is represented in transformer models and show that it can be disentangled in the weights of both language models (LMs) and vision transformers (ViTs) using a decomposition based on loss landscape curvature. This insight builds on prior theoretical and empirical work showing that the curvature for memorized training points is much sharper than for non-memorized ones, so ordering weight components from high to low curvature can separate the two without explicit labels. This motivates a weight editing procedure that suppresses recitation of untargeted memorized data far more effectively than a recent unlearning method (BalancedSubnet), while maintaining lower perplexity. Since the curvature basis has a natural interpretation in terms of shared structure in model weights, we extensively analyze the editing procedure's effect on downstream tasks in LMs. We find that closed-book fact retrieval and arithmetic are specifically and consistently harmed, even though open-book fact retrieval and general logical reasoning are preserved. We posit that these tasks rely heavily on specialized directions in weight space rather than on general-purpose mechanisms, regardless of whether the individual datapoints are memorized. We support this by showing a correspondence between the activation strength of task data on the low-curvature components that we edit out and the drop in task performance after the edit. Our work advances the understanding of memorization in neural networks, with practical applications for removing it, and provides evidence for idiosyncratic, narrowly used structures involved in solving tasks like math and fact retrieval.
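The supporting evidence above pairs how strongly task data activates the edited-out low-curvature components with the performance drop after the edit. One hypothetical way to score this for a single layer is to compare the output energy produced by the removed part of the weights, `W - W_edit`, with the energy produced by the full layer. The function name `removed_activation_score` and this particular score are illustrative assumptions, not the paper's metric. The toy edit below simply zeroes two input coordinates as a stand-in for removed low-curvature directions.

```python
import numpy as np


def removed_activation_score(W, W_edit, X):
    """Hypothetical score: output energy carried by the removed weights
    (W - W_edit) on task inputs X (n, d_in), relative to the full layer's
    output energy. Not guaranteed to lie in [0, 1], since kept and removed
    outputs can partially cancel.
    """
    removed_out = X @ (W - W_edit).T  # (n, d_out) contribution of the edited-out part
    full_out = X @ W.T  # (n, d_out) original layer output
    return float(np.sum(removed_out ** 2) / np.sum(full_out ** 2))


# Toy example: the "edit" zeroes input directions 4 and 5.
rng = np.random.default_rng(1)
W = rng.normal(size=(3, 6))
W_edit = W.copy()
W_edit[:, 4:] = 0.0

X_spared = np.zeros((10, 6))
X_spared[:, :4] = rng.normal(size=(10, 4))  # lives only in kept directions
X_hit = np.zeros((10, 6))
X_hit[:, 4:] = rng.normal(size=(10, 2))  # lives only in removed directions

print(removed_activation_score(W, W_edit, X_spared))
print(removed_activation_score(W, W_edit, X_hit))
```

Inputs confined to the kept directions score 0, and inputs confined to the removed directions score 1 (up to floating-point rounding). Across tasks, such per-task scores could then be compared with measured post-edit accuracy drops, for example with `np.corrcoef`.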
