Attribute-to-Delete: Machine Unlearning via Datamodel Matching

Kristian Georgiev, Roy Rinberg, Sung Min Park, Shivam Garg, Andrew Ilyas, Aleksander Madry, Seth Neel

TL;DR

This work introduces Datamodel Matching (DMM), a machine unlearning meta-algorithm that exhibits strong empirical performance even in non-convex settings. DMM uses data attribution to predict the outputs the model would produce if it were re-trained on all but the forget set points, then fine-tunes the pre-trained model to match these predicted outputs.

Abstract

Machine unlearning -- efficiently removing the effect of a small "forget set" of training data on a pre-trained machine learning model -- has recently attracted significant research interest. Despite this interest, however, recent work shows that existing machine unlearning techniques do not hold up to thorough evaluation in non-convex settings. In this work, we introduce a new machine unlearning technique that exhibits strong empirical performance even in such challenging settings. Our starting point is the perspective that the goal of unlearning is to produce a model whose outputs are statistically indistinguishable from those of a model re-trained on all but the forget set. This perspective naturally suggests a reduction from the unlearning problem to that of data attribution, where the goal is to predict the effect of changing the training set on a model's outputs. Thus motivated, we propose the following meta-algorithm, which we call Datamodel Matching (DMM): given a trained model, we (a) use data attribution to predict the output of the model if it were re-trained on all but the forget set points; then (b) fine-tune the pre-trained model to match these predicted outputs. In a simple convex setting, we show how this approach provably outperforms a variety of iterative unlearning algorithms. Empirically, we use a combination of existing evaluations and a new metric based on the KL-divergence to show that even in non-convex settings, DMM achieves strong unlearning performance relative to existing algorithms. An added benefit of DMM is that it is a meta-algorithm, in the sense that future advances in data attribution translate directly into better unlearning algorithms, pointing to a clear direction for future progress in unlearning.
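
The two steps of the meta-algorithm, (a) predicting oracle outputs with data attribution and (b) fine-tuning toward those predictions, can be illustrated with a minimal sketch. Everything below is an assumption made for illustration rather than the authors' implementation: a linear model stands in for the trained network, the attribution matrix is taken as given, the predicted oracle output is formed by subtracting the attribution scores of the forget points from the full-model output, and the function names `predict_oracle_outputs` and `fine_tune_to_targets` are hypothetical.

```python
import numpy as np


def predict_oracle_outputs(full_outputs, attributions, forget_idx):
    """Step (a): estimate outputs of a model re-trained without the forget set.

    full_outputs: shape (n_eval,), outputs of the model trained on all data.
    attributions: shape (n_eval, n_train); entry [j, i] is the assumed
        (linear) effect of training point i on the output at eval point j.
    forget_idx: indices of the training points to forget.
    """
    # Assumption: removing a point subtracts its attributed contribution.
    return full_outputs - attributions[:, forget_idx].sum(axis=1)


def fine_tune_to_targets(theta, X_eval, targets, lr=0.05, steps=200):
    """Step (b): fine-tune parameters so the outputs match the targets.

    A linear model (output = X_eval @ theta) trained with gradient descent
    on mean squared error stands in for fine-tuning a neural network.
    """
    theta = theta.copy()  # leave the caller's parameters untouched
    n = len(X_eval)
    for _ in range(steps):
        residual = X_eval @ theta - targets
        theta -= lr * 2.0 * X_eval.T @ residual / n
    return theta
```

In this sketch the quality of the result depends entirely on how well the attribution matrix predicts re-training, which mirrors the abstract's point that better data attribution translates directly into better unlearning.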

Paper Structure

This paper contains 80 sections, 5 theorems, 36 equations, 20 figures, 1 table, 9 algorithms.

Key Result

Theorem 1

Let $S$ and $S_R$ be the full training set and the retain set respectively, with input matrices $X$ and $X_R$ and corresponding labels $y$ and $y_R$. Additionally, let $\theta_{\text{full}}$ and $\theta_*$ denote the optima of the ridge objective eq:ridge_objective for the full data $S$ and retain s

Figures (20)

  • Figure 1: Effective unlearning via predictive data attribution. We apply different approximate unlearning methods to trained DNNs to unlearn selected forget sets from CIFAR-10 and ImageNet-Living-17. KLoM scores (y-axis) measure the quality of unlearning by computing the distributional distance between unlearned predictions and oracle predictions (0 means perfect unlearning). To contextualize each method's efficiency, we also show the amount of compute relative to full re-training (x-axis). We evaluate KLoM values over points in the forget, retain, and validation sets to ensure that unlearning is effective across all datapoints, and report the 95th percentile in each group; we also report their average (1st column). Our new methods leveraging data attribution (dm-direct and dmm) dominate the Pareto frontier of existing unlearning methods, and approach the unlearning quality of oracle models (full re-training) at a much smaller fraction of the cost.
  • Figure 2: The missing targets problem. We apply the SCRUB algorithm [kurmanji2023towards] to unlearn a forget set of CIFAR-10, and measure how well different (random) points are unlearned over time. To quantify how well a given point $x$ is unlearned, we fit a Gaussian distribution to the outputs of oracle models at $x$, and compute the likelihood of the average outputs from unlearned models under this distribution. We track this likelihood (y-axis) for random points across the duration of the unlearning algorithm (x-axis). For many examples in the forget set (shown in red), unlearning quality is hurt by training for too long, as we lack access to oracle targets.
  • Figure 3: Oracle matching can efficiently approximate re-training. The KLoM metric (y-axis) measures the distributional difference between unlearned predictions and oracle predictions (0 being perfect). We also show the amount of compute relative to full re-training (x-axis). We evaluate KLoM values over points in the forget, retain, and validation sets and report the $95$th percentile in each group; we also report the average across groups (1st column).
  • Figure 4: Datamodels predict oracle outputs. We examine the accuracy of datamodel predictions for unlearning a CIFAR-10 forget set (ID 5). For random samples from the forget and retain sets, we compare the distribution (across multiple runs) of margins when evaluated on that example across three settings: i) null (model on full dataset); ii) oracle (model re-trained without forget set); and iii) unlearned (using dm-direct, applied to instances of null models). In every case, the predicted outputs (orange) closely match the ground-truth (oracle), demonstrating the effectiveness of datamodels as a proxy for oracle outputs.
  • Figure 5: Oracle matching circumvents the stopping time problem. We revisit the earlier analysis for SCRUB (left) and apply the same analysis to Oracle Matching (right). The red lines highlight examples in the forget set whose unlearning quality is hurt by training longer. This "overshooting" happens frequently with SCRUB, but only rarely with Oracle Matching.
  • ...and 15 more figures
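
The KLoM scores reported in the captions above compare, point by point, the distribution of unlearned-model outputs with the distribution of oracle outputs, and then summarize each group by its 95th percentile. A minimal sketch of such a computation follows. It rests on assumptions not confirmed by this page: each per-point margin distribution is approximated by a univariate Gaussian fit across independent runs, the divergence is taken from oracle to unlearned, and the helper names `gaussian_kl`, `klom_per_point`, and `summarize` are hypothetical.

```python
import numpy as np


def gaussian_kl(mu_p, var_p, mu_q, var_q):
    """KL(N(mu_p, var_p) || N(mu_q, var_q)) for univariate Gaussians."""
    return 0.5 * (np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0)


def klom_per_point(oracle_margins, unlearned_margins, eps=1e-8):
    """Per-point divergence between oracle and unlearned margin distributions.

    Both inputs have shape (n_models, n_points): one row per independently
    trained (or unlearned) model, one column per evaluation point.
    """
    mu_o = oracle_margins.mean(axis=0)
    var_o = oracle_margins.var(axis=0) + eps  # eps guards against zero variance
    mu_u = unlearned_margins.mean(axis=0)
    var_u = unlearned_margins.var(axis=0) + eps
    # Assumed direction: oracle distribution first, unlearned second.
    return gaussian_kl(mu_o, var_o, mu_u, var_u)


def summarize(scores, q=95):
    """Report the q-th percentile of per-point scores within one group."""
    return np.percentile(scores, q)
```

Reporting a high percentile rather than the mean, as the captions do, keeps a small number of badly unlearned points from being hidden by many well-unlearned ones.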

Theorems & Definitions (10)

  • Definition 1: Exact unlearning [ginart2019making]
  • Definition 2: $(\varepsilon, \delta)$-unlearning [neel2021deletion]
  • Definition 3: KL divergence of margins (KLoM)
  • Theorem 1: Proof in \ref{sec:thm_proof}
  • Theorem 2: Proof in \ref{app:som_separation}
  • Theorem 2: Proof in \ref{sec:thm_proof}
  • proof
  • Theorem 2: Proof in \ref{app:som_separation}
  • proof
  • Theorem 3: Theorem 5.8 of [garrigos2023handbook]