Distributional Machine Unlearning via Selective Data Removal
Youssef Allouah, Rachid Guerraoui, Sanmi Koyejo
TL;DR
This work tackles domain-level unlearning in ML by introducing distributional unlearning, a framework that uses KL-divergence constraints to push the edited data distribution away from an unwanted distribution while keeping it close to a retained one. It establishes a formal $(\alpha, \epsilon)$-Pareto frontier for Gaussian and exponential-family models, and proves downstream log-loss guarantees for models trained on the edited data. The authors propose a distance-based selective removal algorithm and show quadratic gains in sample efficiency over random deletion in low-divergence regimes. They validate the approach on synthetic data and on real-world tasks such as CIFAR-10 and Jigsaw Toxic Comments, and observe that it works in synergy with existing sample-level unlearning methods. The results suggest that strong unlearning effects can be achieved with substantially smaller forget sets, enabling scalable and principled subpopulation unlearning with practical downstream robustness.
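As a reading aid, the two KL-divergence constraints described above can be written as a pair of inequalities. This is only a sketch: the symbols $p_1$ (unwanted distribution), $p_2$ (retained distribution), and $p$ (distribution of the edited data), as well as the argument order inside each KL term, are assumptions made here for illustration and are not taken from the summary itself.

```latex
% Assumed notation: p_1 = unwanted distribution, p_2 = retained distribution,
% p = distribution of the edited data. The KL argument order is an assumption.
\[
  \mathrm{KL}(p_1 \,\|\, p) \ge \alpha
  \qquad \text{and} \qquad
  \mathrm{KL}(p_2 \,\|\, p) \le \epsilon .
\]
```

Under this reading, a larger $\alpha$ means stronger forgetting and a smaller $\epsilon$ means better preservation; the Pareto frontier describes which $(\alpha, \epsilon)$ pairs are jointly achievable.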
Abstract
Machine learning systems increasingly face requirements to remove entire domains of information -- such as toxic language or biases -- rather than individual user data. This task presents a dilemma: full removal of the unwanted domain data is computationally expensive, while random partial removal is statistically inefficient. We find that a domain's statistical influence is often concentrated in a small subset of its data samples, suggesting a path between ineffective partial removal and unnecessary complete removal. We formalize this as distributional unlearning: a framework to select a small subset that balances forgetting an unwanted distribution while preserving a desired one. Using Kullback-Leibler divergence constraints, we derive the exact removal-preservation Pareto frontier for exponential families and prove that models trained on the edited data achieve corresponding log-loss bounds. We propose a distance-based selection algorithm and show it is quadratically more sample-efficient than random removal in the challenging low-divergence regime. Experiments across synthetic, text, and image datasets (Jigsaw, CIFAR-10, SMS spam) show our method requires 15-82% less deletion than full removal for strong unlearning effects, e.g., halving initial forget set accuracy. Ultimately, by showing a small forget set often suffices, our framework lays the foundations for more scalable and rigorous subpopulation unlearning.
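The contrast between distance-based selection and random removal can be illustrated on a toy 1-D Gaussian problem. The sketch below is not the authors' implementation: the selection rule (delete the forget samples farthest from the retained sample mean), the function names, and all parameter values are assumptions chosen for illustration. It fits a single Gaussian to the edited data and reports its KL divergence from the known forget and retain distributions.

```python
import numpy as np


def gaussian_kl(mu_p, var_p, mu_q, var_q):
    """KL( N(mu_p, var_p) || N(mu_q, var_q) ) for 1-D Gaussians."""
    return 0.5 * (np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0)


def select_by_distance(forget, retain, k):
    """Indices of the k forget samples farthest from the retain mean.

    Hypothetical selection rule, used here only as an illustration.
    """
    dist = np.abs(forget - retain.mean())
    return np.argsort(dist)[-k:]


def select_random(forget, k, rng):
    """Indices of k forget samples chosen uniformly without replacement."""
    return rng.choice(len(forget), size=k, replace=False)


def fit_after_removal(forget, retain, removed_idx):
    """Fit a Gaussian (mean, variance) to the data left after deletion."""
    kept = np.delete(forget, removed_idx)
    edited = np.concatenate([kept, retain])
    return edited.mean(), edited.var()


def compare_strategies(seed=0, n=2000, k=600):
    """Compare both strategies under assumed toy parameters.

    Forget distribution: N(1, 1). Retain distribution: N(0, 1).
    """
    rng = np.random.default_rng(seed)
    mu_f, var_f = 1.0, 1.0  # assumed forget distribution parameters
    mu_r, var_r = 0.0, 1.0  # assumed retain distribution parameters
    forget = rng.normal(mu_f, np.sqrt(var_f), size=n)
    retain = rng.normal(mu_r, np.sqrt(var_r), size=n)

    results = {}
    strategies = {
        "distance": select_by_distance(forget, retain, k),
        "random": select_random(forget, k, rng),
    }
    for name, idx in strategies.items():
        mu, var = fit_after_removal(forget, retain, idx)
        results[name] = {
            # Larger is better: the fitted model is far from the forget data.
            "kl_forget": gaussian_kl(mu_f, var_f, mu, var),
            # Smaller is better: the fitted model stays close to the retain data.
            "kl_retain": gaussian_kl(mu_r, var_r, mu, var),
        }
    return results


if __name__ == "__main__":
    for name, metrics in compare_strategies().items():
        print(name, {key: round(float(val), 4) for key, val in metrics.items()})
```

With the same deletion budget `k`, the distance-based rule in this toy setting moves the fitted model further from the forget distribution than random deletion does, which is the qualitative behavior the abstract describes. The quadratic sample-efficiency gap itself is a theoretical result of the paper and is not demonstrated by this sketch.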
