A mean teacher algorithm for unlearning of language models
Yegor Klochkov
TL;DR
This work tackles the challenge of unlearning in large language models by revisiting the mean teacher framework as a proximal optimization method that traces a slow natural gradient descent trajectory, reducing memorization while preserving utility. It introduces negative log-unlikelihood (NLUL), a simple loss that avoids vanishing gradients and pairs well with the mean teacher, and demonstrates competitive performance on the MUSE benchmarks using a subset of pretraining data for regularization. The results reveal a trade-off: substantial reductions in verbatim and knowledge memorization and lower privacy leakage can come with declines in global knowledge metrics such as MMLU, and unlearned knowledge can re-emerge after unrelated fine-tuning, underscoring the need for robust evaluation. Overall, the mean teacher with NLUL offers a practical, competitive approach to weight-based unlearning with favorable privacy implications, complemented by reproducible experiments and open-source code.
Abstract
One of the goals of language model unlearning is to reduce memorization of selected text instances while retaining the model's general abilities. Despite the various methods proposed, reducing memorization of large datasets without noticeable degradation in model utility remains challenging. In this paper, we investigate the mean teacher algorithm (Tarvainen & Valpola, 2017), a simple proximal optimization method from the continual learning literature that gradually modifies the teacher model. We show that the mean teacher can approximate the trajectory of slow natural gradient descent (NGD), which inherently seeks low-curvature updates that are less likely to degrade model utility. Because slow NGD can suffer from vanishing gradients, we introduce a new unlearning loss called "negative log-unlikelihood" (NLUL) that avoids this problem. We show that the combination of the mean teacher and NLUL improves some metrics on the MUSE benchmarks (Shi et al., 2024).
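To make the two ingredients concrete, the following is a minimal toy sketch, not the paper's implementation. It makes several assumptions that the abstract does not confirm. The "model" is a single categorical distribution whose logits are the parameters, standing in for one next-token distribution. NLUL is taken to have the form -log(1 - p) on the probability p of the text to forget, which is what the name suggests; the paper's exact definition may differ. The student is coupled to the teacher through a KL proximal term, and the teacher is updated as an exponential moving average (EMA) of the student, as in the original mean teacher. All function names and hyperparameters (`lam`, `lr`, `inner_steps`, `ema`, `outer_steps`) are hypothetical.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax of a 1-D logit vector."""
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()


def nlul_loss_and_grad(logits: np.ndarray, target: int):
    """ASSUMED NLUL form: -log(1 - p[target]), plus its gradient w.r.t. logits.

    With p = softmax(logits), d/dlogits of -log(1 - p_t) = p_t / (1 - p_t) * (e_t - p).
    """
    p = softmax(logits)
    p_t = p[target]
    loss = -np.log1p(-p_t)
    onehot = np.zeros_like(p)
    onehot[target] = 1.0
    grad = (p_t / (1.0 - p_t)) * (onehot - p)
    return loss, grad


def kl_grad(student_logits: np.ndarray, teacher_logits: np.ndarray) -> np.ndarray:
    """Gradient of KL(teacher || student) w.r.t. the student logits: p_student - p_teacher."""
    return softmax(student_logits) - softmax(teacher_logits)


def mean_teacher_unlearn(logits, target, lam=1.0, lr=0.5,
                         inner_steps=5, ema=0.9, outer_steps=50):
    """Hypothetical mean-teacher unlearning loop on a toy categorical model.

    The student takes gradient steps on NLUL + lam * KL(teacher || student),
    which keeps it close to the teacher (the proximal step); the teacher then
    moves slowly toward the student via an EMA of the parameters.
    """
    teacher = np.array(logits, dtype=float)
    student = teacher.copy()
    for _ in range(outer_steps):
        for _ in range(inner_steps):
            _, g_forget = nlul_loss_and_grad(student, target)
            g = g_forget + lam * kl_grad(student, teacher)
            student = student - lr * g
        # Teacher update: exponential moving average of the student parameters.
        teacher = ema * teacher + (1.0 - ema) * student
    return teacher, student


if __name__ == "__main__":
    init = np.array([2.0, 0.5, 0.0, -1.0])  # toy logits; token 0 is the one to forget
    teacher, student = mean_teacher_unlearn(init, target=0)
    print("p(forget token) before:", softmax(init)[0])
    print("p(forget token) after (teacher):", softmax(teacher)[0])
```

In this toy, the teacher, rather than the student, would serve as the unlearned model, since it changes slowly and stays close to the previous iterate at every step. That slow drift is the proximal behavior the abstract attributes to the mean teacher.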
