Learning to Forget using Hypernetworks
Jose Miguel Lara Rangel, Stefan Schoepf, Jack Foster, David Krueger, Usman Anwar
TL;DR
HyperForget is a diffusion-based hypernetwork framework for machine unlearning: it samples model parameters that forget a targeted data subset while preserving performance on the retained data. The method casts forgetting as a generative process, training a hypernetwork to produce parameter samples that maximize loss on the forget set $D_f$ and minimize loss on the retain set $D_r$, which enables efficient, dynamic unlearning without full retraining. Two implementations, DiHyFo-1 and DiHyFo-2, condition the diffusion hypernetwork on different loss signals. Both are validated on MNIST-based tasks, where the sampled models reach zero accuracy on $D_f$ while maintaining performance on $D_r$, and metrics like prompt alignment and membership inference attacks (MIA) indicate close resemblance to retrained baselines. The work demonstrates a promising direction for adaptive, model-intrinsic unlearning and highlights open challenges around scalability, privacy guarantees, and applicability to broader datasets.
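To make the two-sided objective concrete, the following is a minimal sketch of how a candidate set of parameters could be scored: low loss on $D_r$ is rewarded and low loss on $D_f$ is penalized. It is not the paper's diffusion hypernetwork. The linear softmax classifier, the trade-off weight `lam`, the toy data, and the two hand-written candidates are all assumptions made for illustration, and choosing the best candidate by direct comparison stands in for sampling from a trained hypernetwork.

```python
import math


def cross_entropy(weights, x, label):
    """Softmax cross-entropy of a linear classifier.

    weights[c] is the weight vector of class c (illustrative model, not the paper's).
    """
    logits = [sum(w_i * x_i for w_i, x_i in zip(w, x)) for w in weights]
    m = max(logits)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_norm - logits[label]


def mean_loss(weights, dataset):
    """Average cross-entropy over a list of (features, label) pairs."""
    return sum(cross_entropy(weights, x, y) for x, y in dataset) / len(dataset)


def unlearning_objective(weights, retain_set, forget_set, lam=1.0):
    """Lower is better: small loss on D_r, large loss on D_f.

    lam is a hypothetical trade-off weight, not a parameter from the paper.
    """
    return mean_loss(weights, retain_set) - lam * mean_loss(weights, forget_set)


# Toy data: classes 0 and 1 are retained, class 2 is the forget target.
retain_set = [([1.0, 0.0, 0.0], 0), ([0.0, 1.0, 0.0], 1)]
forget_set = [([0.0, 0.0, 1.0], 2)]

# Two hand-written parameter candidates standing in for hypernetwork samples.
candidates = {
    "remembers_all": [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]],
    "forgets_class_2": [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, -4.0]],
}

scores = {
    name: unlearning_objective(w, retain_set, forget_set)
    for name, w in candidates.items()
}
best = min(scores, key=scores.get)
print(best)
```

Both candidates fit the retain set equally well, so the objective separates them only through the forget term, which is the behavior the summary above describes.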
Abstract
Machine unlearning is gaining increasing attention as a way to remove the effects of adversarial data poisoning attacks from already trained models and to comply with privacy and AI regulations. The objective is to unlearn the effect of undesired data from a trained model while maintaining performance on the remaining data. This paper introduces HyperForget, a novel machine unlearning framework that leverages hypernetworks (neural networks that generate parameters for other networks) to dynamically sample models that lack knowledge of targeted data while preserving essential capabilities. Leveraging diffusion models, we implement two Diffusion HyperForget Networks and use them to sample unlearned models in proof-of-concept experiments. The unlearned models obtained zero accuracy on the forget set while preserving good accuracy on the retain sets, highlighting the potential of HyperForget for dynamic targeted data removal and a promising direction for developing adaptive machine unlearning algorithms.
