Sharpness-Aware Machine Unlearning
Haoran Tang, Rajiv Khanna
TL;DR
This work addresses the challenge of forgetting specific training data without retraining from scratch by analyzing Sharpness-Aware Minimization (SAM) within a signal-noise unlearning framework. It reveals that SAM's usual denoising effect can break down when forget signals are optimized with NegGrad, prompting a revised understanding of the retain-forget balance and introducing a threshold that depends on signal strength. Building on these insights, the authors propose Sharp MinMax, a two-branch approach that preserves retain performance via SAM on the retain branch while aggressively forgetting the target data through sharpness maximization on a dedicated forget branch. Extensive experiments on CIFAR-100 and ImageNet demonstrate that SAM-enhanced unlearning improves the forget-retain-test tradeoff (ToW), reduces feature entanglement, and improves robustness to membership inference attacks, with results that generalize across data corruptions, optimizers, and architectures. The work thus provides both theoretical and practical advances for scalable, reliable machine unlearning.
Abstract
We characterize the effectiveness of sharpness-aware minimization (SAM) under the machine unlearning scheme, where unlearning forget signals interferes with learning retain signals. While previous work proves that SAM improves generalization by preventing noise memorization, we show that SAM abandons this denoising property when fitting the forget set, leading to altered generalization that depends on signal strength. We further characterize the signal surplus of SAM in the order of signal strength, which enables learning from fewer retain signals to maintain model performance while putting more weight on unlearning the forget set. Empirical studies show that SAM outperforms SGD with a relaxed requirement for retain signals and can enhance various unlearning methods as either the pretraining or the unlearning algorithm. Motivated by our refined characterization of SAM unlearning and by the observation that overfitting can benefit more stringent sample-specific unlearning, we propose Sharp MinMax, which splits the model into two to learn retain signals with SAM and unlearn forget signals with sharpness maximization, achieving the best performance. Extensive experiments show that SAM enhances unlearning across varying difficulties measured by memorization, yielding decreased feature entanglement between retain and forget sets, stronger resistance to membership inference attacks, and a flatter loss landscape. Our observations generalize to noisier data, different optimizers, and different architectures.
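To make the interplay between SAM and NegGrad-style unlearning concrete, the following is a minimal sketch on a toy linear classifier with logistic loss. It is not the authors' implementation and does not reproduce Sharp MinMax, which splits the model into two branches. The function names, the toy data, and the parameters `rho` (SAM perturbation radius), `alpha` (retain-forget weight), and `lr` are assumptions chosen for illustration. The SAM gradient is computed in the standard way: take the gradient, step a distance `rho` along its normalized direction, and evaluate the gradient again at the perturbed weights.

```python
import math

def logistic_loss(w, batch):
    """Mean logistic loss log(1 + exp(-y * <w, x>)) over (x, y) pairs, y in {-1, +1}."""
    total = 0.0
    for x, y in batch:
        margin = y * sum(wi * xi for wi, xi in zip(w, x))
        total += math.log1p(math.exp(-margin))
    return total / len(batch)

def logistic_grad(w, batch):
    """Gradient of logistic_loss with respect to w."""
    g = [0.0] * len(w)
    for x, y in batch:
        margin = y * sum(wi * xi for wi, xi in zip(w, x))
        coeff = -y / (1.0 + math.exp(margin))
        for i, xi in enumerate(x):
            g[i] += coeff * xi / len(batch)
    return g

def sam_grad(w, batch, rho=0.05):
    """SAM gradient: the gradient evaluated at w + rho * g / ||g||."""
    g = logistic_grad(w, batch)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # avoid division by zero
    w_perturbed = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    return logistic_grad(w_perturbed, batch)

def unlearn_step(w, retain, forget, lr=0.1, alpha=0.5, rho=0.05):
    """One combined step: descend on the retain loss, ascend on the forget loss.

    alpha is a hypothetical weight: alpha on the retain term and
    (1 - alpha) on the NegGrad (gradient ascent) forget term.
    """
    g_retain = sam_grad(w, retain, rho)
    g_forget = sam_grad(w, forget, rho)
    return [
        wi - lr * (alpha * gr - (1.0 - alpha) * gf)
        for wi, gr, gf in zip(w, g_retain, g_forget)
    ]

if __name__ == "__main__":
    # Toy data (illustrative only): retain and forget points lie along different axes.
    retain = [([1.0, 0.0], 1), ([0.9, 0.1], 1)]
    forget = [([0.0, 1.0], 1), ([0.1, 0.9], 1)]
    w = [0.0, 0.0]
    for _ in range(20):
        w = unlearn_step(w, retain, forget)
    print("retain loss:", logistic_loss(w, retain))
    print("forget loss:", logistic_loss(w, forget))
```

In this sketch, lowering `alpha` shifts weight from the retain term to the forget term, which mirrors the abstract's point about putting more weight on unlearning the forget set when fewer retain signals suffice.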
