A mean teacher algorithm for unlearning of language models

Yegor Klochkov

TL;DR

This work tackles the challenge of unlearning in large language models by revisiting the mean teacher framework as a proximal optimization method that traces a slow natural gradient descent trajectory to minimize memorization while preserving utility. It introduces NLUL, a simple loss that avoids gradient vanishing and pairs well with mean teacher, and demonstrates competitive performance on the MUSE benchmarks using a pretraining data subset for regularization. The results reveal a trade-off: substantial reduction of verbatim and knowledge memorization and lower privacy leakage can come with declines in global knowledge metrics like MMLU, and knowledge can re-emerge after unrelated fine-tuning, underscoring the need for robust evaluation. Overall, mean teacher with NLUL offers a practical, competitive approach to weight-based unlearning with favorable privacy implications, complemented by reproducible experiments and open-source code.

Abstract

One of the goals of language model unlearning is to reduce memorization of selected text instances while retaining the model's general abilities. Despite various proposed methods, reducing memorization of large datasets without noticeable degradation in model utility remains challenging. In this paper, we investigate the mean teacher algorithm (Tarvainen & Valpola, 2017), a simple proximal optimization method from the continual learning literature that gradually modifies the teacher model. We show that the mean teacher can approximate the trajectory of a slow natural gradient descent (NGD), which inherently seeks low-curvature updates that are less likely to degrade model utility. While slow NGD can suffer from vanishing gradients, we introduce a new unlearning loss called "negative log-unlikelihood" (NLUL) that avoids this problem. We show that the combination of mean teacher and NLUL improves some metrics on the MUSE benchmarks (Shi et al., 2024).
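
As a rough illustration of what "gradually modifies the teacher model" means, the sketch below shows the generic exponential-moving-average (EMA) teacher update of Tarvainen & Valpola (2017) on plain Python lists. It is not the paper's unlearning algorithm: the function name `ema_update`, the parameter layout, and the coefficient name `mu` are assumptions made for this example only.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def ema_update(teacher: Params, student: Params, mu: float) -> None:
    """In-place EMA step: teacher <- mu * teacher + (1 - mu) * student.

    `mu` is an assumed name for the averaging coefficient; values close
    to 1 make the teacher move slowly towards the student.
    """
    if not 0.0 <= mu <= 1.0:
        raise ValueError("mu must lie in [0, 1]")
    for name, t_vals in teacher.items():
        s_vals = student[name]
        for i, s in enumerate(s_vals):
            t_vals[i] = mu * t_vals[i] + (1.0 - mu) * s


# Toy parameters (hypothetical values, chosen to be exact in floating point).
teacher = {"w": [0.0, 0.0]}
student = {"w": [1.0, -2.0]}

# With a fixed student, k EMA steps move the teacher a fraction
# 1 - mu**k of the way towards the student.
for _ in range(3):
    ema_update(teacher, student, mu=0.5)
```

With `mu=0.5` and three steps, the teacher has covered `1 - 0.5**3 = 0.875` of the distance to the student. In practice the student itself changes between steps, which this toy loop does not model.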

Paper Structure

This paper contains 22 sections, 3 theorems, 75 equations, 4 figures, 4 tables, 2 algorithms.

Key Result

Theorem 3.1

Set $\gamma := \kappa \alpha \eta / (1 - \kappa \eta)$ and $\overline{\lambda} := \lambda + (1 - \mu) \kappa / (1 - \eta \kappa)$. Consider the following updates. Suppose that (1) $\kappa$ is a positive constant, (2) $\eta$ and $\alpha$ are sufficiently small, (3) the number of steps $T$ is such that $T \gamma$ is bounded by a constant, and (4) $H(\theta)$ is symmetric positive-definite and satisfies Eq. local_q.
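
To get a feel for the two constants defined in Theorem 3.1, the sketch below evaluates $\gamma$ and $\overline{\lambda}$ for given hyperparameters. The function name `theorem_constants` and the numeric values are hypothetical, chosen only for illustration. The positivity check on $1 - \kappa \eta$ is an added assumption, in keeping with the requirement that $\eta$ be sufficiently small.

```python
import math
from typing import Tuple


def theorem_constants(
    kappa: float, alpha: float, eta: float, lam: float, mu: float
) -> Tuple[float, float]:
    """Return (gamma, lambda_bar) as defined in the theorem statement.

    gamma      = kappa * alpha * eta / (1 - kappa * eta)
    lambda_bar = lam + (1 - mu) * kappa / (1 - eta * kappa)
    """
    denom = 1.0 - kappa * eta
    if kappa <= 0.0 or denom <= 0.0:
        # Assumed sanity check: kappa positive, eta small enough.
        raise ValueError("need kappa > 0 and kappa * eta < 1")
    gamma = kappa * alpha * eta / denom
    lambda_bar = lam + (1.0 - mu) * kappa / denom
    return gamma, lambda_bar


# Hypothetical values, not taken from the paper.
gamma, lambda_bar = theorem_constants(
    kappa=1.0, alpha=0.01, eta=0.1, lam=0.0, mu=0.9
)

# Condition (3) bounds T * gamma by a constant C, so roughly T <= C / gamma.
C = 1.0  # assumed bound
max_steps = math.floor(C / gamma)
```

Because $\gamma$ scales with the product $\alpha \eta$, halving either step size roughly doubles the number of steps $T$ allowed under condition (3).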

Figures (4)

  • Figure 1: NLL loss on the forget set for MT using different unlearning losses (green). For IT we only show the KL divergence between the bad teacher and the target model. For NPO/LL we additionally perform 2 epochs with AdamW to "escape" the starting point.
  • Figure 2: Gradient norms during MT training in Figure 1.
  • Figure 3: Comparison of mean teacher (blue), baselines with pretraining data (red), original baselines from Shi et al. (2024) that use the retain split for regularization (green), and recent methods as reported in Bu et al. (2024), Wang et al. (2024), Fan et al. (2024), and Wang et al. (2025) (orange).
  • Figure 4: Sustainability of unlearning: how well utility is preserved under sequential unlearning requests. We run the experiment for mean teacher and NPO, both using pretraining data instead of the MUSE-News retain split.

Theorems & Definitions (5)

  • Theorem 3.1
  • Lemma 4.1
  • proof
  • Lemma 4.3
  • proof