Sharpness-Aware Machine Unlearning

Haoran Tang, Rajiv Khanna

TL;DR

This work addresses the challenge of forgetting specific training data without retraining from scratch by analyzing Sharpness-Aware Minimization (SAM) within a signal-noise unlearning framework. It reveals that SAM’s usual denoising effect can break when optimizing forget signals with NegGrad, prompting a revised understanding of the retain-forget balance and introducing a signal-strength dependent threshold. Building on these insights, the authors propose Sharp MinMax, a two-branch approach that preserves retain performance via SAM on the retain branch while aggressively forgetting the target data through sharpness maximization on a dedicated forget branch. Extensive experiments on CIFAR-100 and ImageNet demonstrate that SAM-enhanced unlearning improves the forget-retain-test tradeoff (ToW), reduces feature entanglement, and improves robustness to membership inference attacks, with results that generalize across data corruptions, optimizers, and architectures. The work thus provides both theoretical and practical advances for scalable, reliable machine unlearning.
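The ToW score mentioned above can be illustrated with a small sketch. It assumes the definition common in prior unlearning work: a product, over the forget, retain, and test sets, of one minus the absolute accuracy gap between the unlearned model and a model retrained from scratch on the retain set only. The exact form used in this paper is not reproduced here, and the function name `tow` and all accuracies below are hypothetical.

```python
def tow(acc_unlearned, acc_retrained):
    """Assumed ToW-style score over the 'forget', 'retain', and 'test' splits.

    Both arguments map split name -> accuracy in [0, 1]. Each factor is
    1 - |gap to the retrained-from-scratch reference|, so 1.0 is a perfect match.
    """
    score = 1.0
    for split in ("forget", "retain", "test"):
        gap = abs(acc_unlearned[split] - acc_retrained[split])
        score *= 1.0 - gap
    return score


# Hypothetical accuracies for an unlearned model and its retrained reference.
unlearned = {"forget": 0.80, "retain": 0.95, "test": 0.70}
retrained = {"forget": 0.75, "retain": 0.97, "test": 0.72}
score = tow(unlearned, retrained)
print(round(score, 4))  # 0.9124
```

Because the three factors are multiplied, a large gap on any single split pulls the whole score down, which is why ToW measures a tradeoff rather than any one accuracy.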

Abstract

We characterize the effectiveness of Sharpness-Aware Minimization (SAM) under the machine unlearning scheme, where unlearning forget signals interferes with learning retain signals. While previous work proves that SAM improves generalization by preventing noise memorization, we show that SAM abandons this denoising property when fitting the forget set, leading to altered generalization that depends on signal strength. We further characterize the signal surplus of SAM in the order of signal strength, which enables the model to learn from fewer retain signals while maintaining performance and to put more weight on unlearning the forget set. Empirical studies show that SAM outperforms SGD under a relaxed requirement for retain signals and can enhance various unlearning methods, either as the pretraining or as the unlearning algorithm. Motivated by our refined characterization of SAM unlearning and the observation that overfitting can benefit more stringent sample-specific unlearning, we propose Sharp MinMax, which splits the model into two branches to learn retain signals with SAM and unlearn forget signals with sharpness maximization, achieving the best performance. Extensive experiments show that SAM enhances unlearning across varying difficulty levels measured by memorization, yielding decreased feature entanglement between retain and forget sets, stronger resistance to membership inference attacks, and a flatter loss landscape. Our observations generalize to noisier data, different optimizers, and different architectures.
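To make the optimizers in the abstract concrete, the sketch below shows SAM's two-step update, SAM applied to a NegGrad objective, and one possible form of the Sharp MinMax per-branch updates, on a toy NumPy logistic regression. Assumptions: NegGrad is taken as the weighted objective $\alpha L_{\mathcal{R}} - (1-\alpha) L_{\mathcal{F}}$, consistent with $\alpha$ weighting retain signals in Figure 2; the forget-branch "sharpness maximization" is modeled, hypothetically, as gradient ascent on the SAM sharpness estimate $L(w+\epsilon) - L(w)$; all function names and hyperparameters are illustrative. How the paper combines the two branches is not described here, so the sketch stops at the per-branch updates.

```python
import numpy as np


def logistic_loss(w, X, y):
    """Mean logistic loss for labels y in {-1, +1}."""
    return np.mean(np.logaddexp(0.0, -y * (X @ w)))


def logistic_grad(w, X, y):
    """Gradient of logistic_loss with respect to w."""
    m = y * (X @ w)
    sig_neg = 0.5 * (1.0 - np.tanh(0.5 * m))  # sigmoid(-m), overflow-free
    return -(X.T @ (y * sig_neg)) / len(y)


def sam_perturbation(g, rho):
    """First-order worst-case perturbation of radius rho, as in SAM."""
    return rho * g / (np.linalg.norm(g) + 1e-12)


def sam_step(w, grad_fn, lr, rho):
    """SAM: take the gradient at w + eps, then apply it at w."""
    eps = sam_perturbation(grad_fn(w), rho)
    return w - lr * grad_fn(w + eps)


def neggrad_grad(w, retain, forget, alpha):
    """Gradient of the assumed NegGrad objective alpha*L_R - (1 - alpha)*L_F."""
    return alpha * logistic_grad(w, *retain) - (1 - alpha) * logistic_grad(w, *forget)


def sharpness_ascent_step(w, X, y, lr, rho):
    """Hypothetical forget-branch step: ascend S(w) = L(w + eps) - L(w).

    eps is treated as a constant when differentiating (SAM's approximation).
    """
    g = logistic_grad(w, X, y)
    eps = sam_perturbation(g, rho)
    return w + lr * (logistic_grad(w + eps, X, y) - g)


# Toy, linearly separable data; the last 20 points play the forget set.
rng = np.random.default_rng(0)
d = 5
X = rng.normal(size=(200, d))
y = np.sign(X @ rng.normal(size=d))
y[y == 0] = 1.0
retain, forget = (X[:180], y[:180]), (X[180:], y[180:])

# Pretrain on all data with SAM.
w = np.zeros(d)
for _ in range(200):
    w = sam_step(w, lambda v: logistic_grad(v, X, y), lr=0.5, rho=0.05)

# Unlearn with SAM on the NegGrad objective (alpha weights the retain loss).
w_u = w.copy()
for _ in range(20):
    w_u = sam_step(w_u, lambda v: neggrad_grad(v, retain, forget, alpha=0.9), lr=0.1, rho=0.05)

# Sharp MinMax-style branches: SAM on retain, sharpness ascent on forget.
w_retain_branch, w_forget_branch = w.copy(), w.copy()
for _ in range(20):
    w_retain_branch = sam_step(w_retain_branch, lambda v: logistic_grad(v, *retain), lr=0.1, rho=0.05)
    w_forget_branch = sharpness_ascent_step(w_forget_branch, *forget, lr=0.1, rho=0.05)
```

SAM applies the gradient from the perturbed point at the original weights, so each SAM step costs two gradient evaluations; the forget-branch step uses the same two evaluations but moves up their difference instead.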

Paper Structure

This paper contains 48 sections, 8 theorems, 61 equations, 9 figures, 17 tables, 2 algorithms.

Key Result

Lemma 3.1

(Noise memorization of $\mathcal{F}$ by SAM under NegGrad). Under the NegGrad scheme, and assuming Assumption \ref{assumption:assumption} holds, for class $j$ we have that if for SGD $\langle\mathbf{w}^{\mathbf{t}}, \bm{\xi}_k\rangle \geq 0$, $k \in \mathcal{I}_{\mathbf{t}}^{\mathcal{R}}$ and $j=y_k$, then for S

Figures (9)

  • Figure 1: UMAP [mcinnes2018umap] feature analysis on Mid Mem $\mathcal{F}_\text{mid}$. At the all-class level, we observe that SAM better maintains class clusters after unlearning, while SGD forms a more evident clump of features; at the classwise level, we observe that while both push away forget features, SGD also scatters retain features further, suggesting overfitting. This also explains the larger clump of SGD at the all-class level. We further observe that SAM pushes forget features farther away on $\mathcal{F}_\text{high}$ and that SGD scatters more retain features on $\mathcal{F}_\text{low}$; see App. \ref{supp:feature} for full visualizations.
  • Figure 2: As $\alpha$ decreases, NegGrad puts less weight on retain signals and learns more from $\mathcal{F}$, leading to harmful overfitting. SAM exhibits more tolerance to insufficient retain signals, while $\mathcal{A},\mathcal{U}=\text{SGD}$ collapses the fastest. Note that $\operatorname{ToW}$ starts failing before $\alpha=|\mathcal{R}|/(|\mathcal{F}|+|\mathcal{R}|)$, implying that additional factors affect the $\alpha$ threshold, as we point out.
  • Figure 3: Loss landscapes on $\mathcal{D}_{\text{test}}$ and $\mathcal{F}_{\text{mid}}$, where the first row shows a SAM-pretrained model and SAM-unlearned models, and the second row shows their SGD counterparts. While unlearning increases sharpness, as indicated by reduced basin ratios, SAM-unlearned models still maintain flatter landscapes than SGD models do.
  • Figure 4: $95\%$ confidence intervals $(\mu\pm2\sigma)$ of unlearning methods on ImageNet, corresponding to Tab. \ref{tab:unlearn} and Tab. \ref{tab:minmax}. We run each setting three times with different seeds and compute the statistical significance. SAM consistently improves the base $\mathcal{U}$, and ASAM 1.0 steadily brings the largest improvement.
  • Figure 5: $95\%$ confidence intervals $(\mu\pm2\sigma)$ of unlearning methods on CIFAR-100, corresponding to Tab. \ref{tab:unlearn} and Tab. \ref{tab:minmax}. We run each setting three times with different seeds and compute the statistical significance. SAM not only improves the $\operatorname{ToW}$ of the base methods but is also more robust to variance than SGD.
  • ...and 4 more figures
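Two quantities quoted in the captions above, the $\mu\pm2\sigma$ interval over three seeds and the reference value $\alpha=|\mathcal{R}|/(|\mathcal{F}|+|\mathcal{R}|)$, reduce to simple arithmetic. This sketch assumes $\sigma$ is the sample standard deviation across seeds (the captions do not say), and the function names, per-seed scores, and set sizes below are hypothetical.

```python
import statistics


def mean_pm_2sigma(values):
    """Return (mu - 2*sigma, mu + 2*sigma) over per-seed scores.

    Assumes sigma is the sample standard deviation (n - 1 denominator).
    """
    mu = statistics.mean(values)
    sigma = statistics.stdev(values)
    return mu - 2 * sigma, mu + 2 * sigma


def alpha_threshold(n_retain, n_forget):
    """Reference weight |R| / (|F| + |R|) on the retain loss."""
    return n_retain / (n_forget + n_retain)


# Hypothetical ToW scores from three seeds.
low, high = mean_pm_2sigma([0.90, 0.92, 0.94])
print(f"ToW interval: [{low:.2f}, {high:.2f}]")  # ToW interval: [0.88, 0.96]

# Hypothetical split: 5,000 forget samples out of 50,000 training samples.
print(alpha_threshold(n_retain=45_000, n_forget=5_000))  # 0.9
```

The $n-1$ denominator is the usual choice for a handful of seeds; using the population standard deviation instead would make the interval slightly narrower.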

Theorems & Definitions (10)

  • Lemma 3.1
  • Theorem 3.2
  • Theorem 3.3
  • Corollary 3.3.1
  • Lemma 3.4
  • Remark D.2
  • Lemma D.3
  • Lemma D.4
  • Lemma D.5
  • Remark D.6