Gauss-Newton Unlearning for the LLM Era
Lev McKinney, Anvith Thudi, Juhan Bae, Tara Rezaei, Nicolas Papernot, Sheila A. McIlraith, Roger Grosse
TL;DR
Gauss-Newton Unlearning for the LLM Era introduces K-FADE, a second-order unlearning method for large language models. It uses a forget distribution and a retain distribution, together with K-FAC or EK-FAC Hessian approximations, to compute a small number of Gauss-Newton steps. By converting a constraint on the model's outputs over the retain distribution into a constraint on its weights, it suppresses outputs on the forget set while minimally disrupting retained behavior, often approximating in output space the model obtained by retraining without the forget set. Empirically, K-FADE achieves state-of-the-art performance on the WMDP and ToFU benchmarks, preserves specificity (lower KL divergence on retained data), and allows unlearning updates to be transferred to finetuned models, with runtime competitive with first-order methods. Limitations include the absence of formal unlearning guarantees and vulnerability to full-rank fine-tuning attacks; nonetheless, the approach scales to frontier models and suggests practical directions for benchmarking and privacy-preserving unlearning.
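How a constraint on retain-set outputs becomes a constraint on weights can be illustrated with a standard second-order trust-region argument. The notation is ours and the setup is an assumption, not necessarily the paper's exact objective: $\mathcal{D}_r$ is the retain distribution, $p_\theta(\cdot \mid x)$ the model's next-token distribution, $J_x$ the Jacobian of the logits with respect to $\theta$, $H_x = \mathrm{diag}(p) - p p^\top$ the Hessian of the softmax cross-entropy with respect to the logits, $g_f = \nabla_\theta \mathcal{L}_f$ the gradient of a forget loss that the step should increase, and $\varepsilon$ a chosen KL budget.

```latex
% Second-order expansion of the change in retain-set outputs (G_r is the Gauss-Newton matrix)
\mathbb{E}_{x\sim\mathcal{D}_r}\,\mathrm{KL}\big(p_\theta(\cdot\mid x)\,\|\,p_{\theta+\Delta\theta}(\cdot\mid x)\big)
  \approx \tfrac{1}{2}\,\Delta\theta^{\top} G_r\,\Delta\theta,
  \qquad G_r = \mathbb{E}_{x\sim\mathcal{D}_r}\big[J_x^{\top} H_x J_x\big]

% Uphill step on the forget loss under the resulting weight-space constraint
\max_{\Delta\theta}\; g_f^{\top}\Delta\theta
  \;\;\text{s.t.}\;\; \tfrac{1}{2}\,\Delta\theta^{\top} G_r\,\Delta\theta \le \varepsilon
  \;\;\Longrightarrow\;\;
  \Delta\theta^{\star} = \sqrt{\frac{2\varepsilon}{g_f^{\top} G_r^{-1} g_f}}\; G_r^{-1} g_f
```

Under K-FAC, the block of $G_r$ belonging to a linear layer with weight matrix $W$ is approximated by a Kronecker product of an input-activation factor $A$ and an output-gradient factor $S$, so applying $G_r^{-1}$ to a layer gradient $V$ reduces to computing $S^{-1} V A^{-1}$.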
Abstract
Standard large language model training can create models that produce outputs their trainer deems unacceptable in deployment. The probability of these outputs can be reduced using methods such as LLM unlearning. However, unlearning a set of data (called the forget set) can degrade model performance on other distributions where the trainer wants to retain the model's behavior. To improve this trade-off, we demonstrate that using the forget set to compute only a few uphill Gauss-Newton steps provides a conceptually simple, state-of-the-art unlearning approach for LLMs. While Gauss-Newton steps adapt Newton's method to non-linear models, it is non-trivial to efficiently and accurately compute such steps for LLMs. Hence, our approach crucially relies on parametric Hessian approximations such as Kronecker-Factored Approximate Curvature (K-FAC). We call this combined approach K-FADE (K-FAC for Distribution Erasure). Our evaluation on the WMDP and ToFU benchmarks demonstrates that K-FADE suppresses outputs from the forget set and approximates, in output space, the results of retraining without the forget set. Critically, our method does this while altering the outputs on the retain set less than previous methods. This is because K-FADE transforms a constraint on the model's outputs across the entire retain set into a constraint on the model's weights, allowing the algorithm to minimally change the model's behavior on the retain set at each step. Moreover, the unlearning updates computed by K-FADE can be reapplied later if the model undergoes further training, allowing unlearning to be cheaply maintained.
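For a single linear layer, a Kronecker-factored curvature estimate can turn a forget-set gradient into a preconditioned uphill step, as sketched below. This is a minimal illustration under stated assumptions, not the authors' implementation. The helper names (`kronecker_factors`, `eigenbasis`, `precondition`, `uphill_step`) are hypothetical. The factors are assumed to come from retain data, with output gradients computed from labels sampled from the model's own predictions so that they estimate the Gauss-Newton matrix. The forget gradient is assumed to be the gradient of a loss the step should increase. The KL budget `max_kl` and the simple additive damping are our choices. Passing per-token data to `eigenbasis` replaces the K-FAC eigenvalues with the EK-FAC eigenvalue correction.

```python
import numpy as np


def kronecker_factors(acts, out_grads):
    """K-FAC factors for one linear layer y = W @ a, with W of shape (d_out, d_in).

    acts:      (N, d_in)  layer inputs on retain data, one row per token
    out_grads: (N, d_out) gradients w.r.t. the layer outputs y; sampling the
               labels from the model's own predictions makes S estimate the
               Gauss-Newton factor (an assumption of this sketch)
    Returns A = E[a a^T] and S = E[g g^T].
    """
    n = acts.shape[0]
    return acts.T @ acts / n, out_grads.T @ out_grads / n


def eigenbasis(A, S, acts=None, out_grads=None):
    """Eigenbasis of the Kronecker product S (x) A and the curvature in that basis.

    Without per-token data the curvature is plain K-FAC: products of factor
    eigenvalues. With per-token data it is the EK-FAC eigenvalue correction:
    the mean squared per-token gradient projected onto the basis.
    """
    lam_A, Q_A = np.linalg.eigh(A)
    lam_S, Q_S = np.linalg.eigh(S)
    if acts is None or out_grads is None:
        scale = np.outer(lam_S, lam_A)
    else:
        g_rot = out_grads @ Q_S  # (N, d_out) output gradients in S's eigenbasis
        a_rot = acts @ Q_A       # (N, d_in)  inputs in A's eigenbasis
        scale = (g_rot ** 2).T @ (a_rot ** 2) / acts.shape[0]
    return Q_S, Q_A, scale


def precondition(grad_W, Q_S, Q_A, scale, damping=1e-3):
    """Apply (curvature + damping * I)^{-1} to a (d_out, d_in) weight gradient."""
    rotated = Q_S.T @ grad_W @ Q_A
    return Q_S @ (rotated / (scale + damping)) @ Q_A.T


def uphill_step(forget_grad, Q_S, Q_A, scale, max_kl=1e-2, damping=1e-3):
    """Step that increases the linearized forget loss as much as possible while
    keeping 0.5 * dW^T (G + damping * I) dW equal to max_kl."""
    direction = precondition(forget_grad, Q_S, Q_A, scale, damping)
    quad = np.sum(forget_grad * direction)  # g^T (G + damping * I)^{-1} g > 0
    return np.sqrt(2.0 * max_kl / quad) * direction


# Toy usage with random stand-ins for real activations and gradients.
rng = np.random.default_rng(0)
d_in, d_out, n = 6, 4, 256
retain_acts = rng.normal(size=(n, d_in))
retain_out_grads = rng.normal(size=(n, d_out))
A, S = kronecker_factors(retain_acts, retain_out_grads)
Q_S, Q_A, scale = eigenbasis(A, S)  # K-FAC; pass the data again for EK-FAC
forget_grad = rng.normal(size=(d_out, d_in))  # stand-in for dL_forget/dW
delta_W = uphill_step(forget_grad, Q_S, Q_A, scale)

W = rng.normal(size=(d_out, d_in))
W_unlearned = W + delta_W  # the stored delta can be added to any same-shaped W
```

Working in the eigenbasis of the two factors keeps every inverse at the size of a single factor rather than the full $d_{out} d_{in} \times d_{out} d_{in}$ matrix. Because the update is stored as a plain weight delta, reapplying it to a later checkpoint is a single addition, as in the last line; how well such a delta transfers after further training is the empirical question the abstract refers to, which this toy example does not demonstrate.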
