
Estimating Training Data Influence by Tracing Gradient Descent

Garima Pruthi, Frederick Liu, Mukund Sundararajan, Satyen Kale

TL;DR

TracIn presents a gradient-based framework to quantify how individual training examples influence a specific test prediction by tracing loss changes along the training trajectory. It moves from an idealized, step-wise formulation to practical approximations that leverage first-order gradients, minibatches, and a checkpoint-based replay (TracInCP) to scale to large models. Empirical results on CIFAR-10, MNIST, and ImageNet show TracIn outperforms influence function and representer baselines in identifying mislabelled data and provides actionable data-centric insights across regression and classification tasks. The method is simple to implement, broadly applicable to any SGD-trained model, and supports diverse applications from data cleaning to active-learning-style data fortification.
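
The move from the idealized formulation to first-order gradients rests on a Taylor expansion of one SGD step. If step $t$ updates $w_{t+1} = w_t - \eta_t \nabla\ell(w_t, z_t)$, then $\ell(w_t, z') - \ell(w_{t+1}, z') \approx \eta_t \nabla\ell(w_t, z') \cdot \nabla\ell(w_t, z_t)$. The minimal sketch below checks this numerically. It assumes a toy linear-regression model with squared loss, and the data and helper names (`loss`, `grad`, `exact_step_influence`, `first_order_step_influence`) are hypothetical illustrations, not the authors' code.

```python
# Toy check of TracIn's first-order approximation (hypothetical helpers).
# Model: linear regression with squared loss l(w, (x, y)) = 0.5 * (w.x - y)^2.

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def loss(w, example):
    x, y = example
    r = dot(w, x) - y
    return 0.5 * r * r

def grad(w, example):
    # Gradient of the squared loss with respect to w: (w.x - y) * x
    x, y = example
    r = dot(w, x) - y
    return [r * xi for xi in x]

def exact_step_influence(w, z, z_test, lr):
    # Take one SGD step on training example z and measure the drop in test loss.
    w_next = [wi - lr * gi for wi, gi in zip(w, grad(w, z))]
    return loss(w, z_test) - loss(w_next, z_test)

def first_order_step_influence(w, z, z_test, lr):
    # First-order Taylor estimate of the same drop: lr * grad(test) . grad(train)
    return lr * dot(grad(w, z_test), grad(w, z))

w = [0.5, -0.2]
z = ([1.0, 2.0], 1.0)        # training example used at this step
z_test = ([0.5, 1.5], 0.8)   # test example whose loss we track
for lr in (0.1, 0.01, 0.001):
    exact = exact_step_influence(w, z, z_test, lr)
    approx = first_order_step_influence(w, z, z_test, lr)
    print(f"lr={lr}: exact={exact:.6f} first-order={approx:.6f}")
```

For this quadratic loss, the gap between the two estimates is exactly $\tfrac{1}{2}\eta^2 (\nabla\ell(w, z) \cdot x')^2$, where $x'$ is the test input. It therefore shrinks quadratically as the learning rate decreases, which is why the approximation is reasonable for the small step sizes typical of SGD.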

Abstract

We introduce a method called TracIn that computes the influence of a training example on a prediction made by the model. The idea is to trace how the loss on the test point changes during the training process whenever the training example of interest was utilized. We provide a scalable implementation of TracIn via: (a) a first-order gradient approximation to the exact computation, (b) saved checkpoints of standard training procedures, and (c) cherry-picking layers of a deep neural network. In contrast with previously proposed methods, TracIn is simple to implement; all it needs is the ability to work with gradients, checkpoints, and loss functions. The method is general. It applies to any machine learning model trained using stochastic gradient descent or a variant of it, agnostic of architecture, domain and task. We expect the method to be widely useful within processes that study and improve training data.
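
Points (a) to (c) combine into the checkpoint-based estimator TracInCP. It sums, over saved checkpoints $w_{t_1}, \dots, w_{t_k}$ with learning rates $\eta_i$, the term $\eta_i \nabla\ell(w_{t_i}, z) \cdot \nabla\ell(w_{t_i}, z')$. The sketch below makes several assumptions. Checkpoints are stored as `(weights, learning_rate)` pairs, and the model is a toy linear regression with squared loss. The parameter `param_indices` is a hypothetical stand-in for cherry-picking layers: it restricts the gradient dot product to a chosen subset of parameters, which in a deep network might be the final layer.

```python
# Sketch of checkpoint-based TracIn (TracInCP) on a toy linear model.
# Assumptions (hypothetical): checkpoints are (weights, learning_rate) pairs;
# loss is 0.5 * (w.x - y)^2; param_indices mimics "cherry-picking layers".

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def grad(w, example):
    x, y = example
    r = dot(w, x) - y
    return [r * xi for xi in x]

def tracin_cp(checkpoints, z, z_test, param_indices=None):
    score = 0.0
    for w, lr in checkpoints:
        g_train, g_test = grad(w, z), grad(w, z_test)
        if param_indices is not None:
            # Keep only the selected parameters (e.g., one layer of a deep net).
            g_train = [g_train[i] for i in param_indices]
            g_test = [g_test[i] for i in param_indices]
        score += lr * dot(g_train, g_test)
    return score

checkpoints = [([0.0, 0.0], 0.1), ([0.5, -0.2], 0.1), ([0.8, 0.1], 0.05)]
z = ([1.0, 2.0], 1.0)
z_test = ([0.5, 1.5], 0.8)
print(tracin_cp(checkpoints, z, z_test))                     # all parameters
print(tracin_cp(checkpoints, z, z_test, param_indices=[1]))  # second parameter only
```

The third checkpoint contributes nothing here because it fits `z` exactly, so its training gradient is zero. With real models only the gradients at the saved checkpoints are needed, and no replay of the full training run is required.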

Paper Structure

This paper contains 37 sections, 1 theorem, 9 equations, 10 figures, 2 tables.

Key Result

Lemma 3.1

Suppose the initial parameter vector before starting the training process is $w_0$, and the final parameter vector is $w_T$. Then $\sum_{i=1}^n \texttt{TracInIdeal}(z_i, z') = \ell(w_{0}, z') - \ell(w_T, z')$
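
Lemma 3.1 follows by telescoping. When each SGD step uses one training example, $\texttt{TracInIdeal}(z, z')$ sums $\ell(w_t, z') - \ell(w_{t+1}, z')$ over the steps $t$ at which $z$ was used, so summing over all training examples covers every step exactly once. The minimal sketch below works under that batch-size-1 assumption, with a hypothetical toy linear-regression model and random sampling of training examples. It accumulates the ideal scores during training and checks the identity numerically.

```python
import random

# Hypothetical toy setup: linear regression, squared loss, SGD with batch size 1.

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def loss(w, example):
    x, y = example
    r = dot(w, x) - y
    return 0.5 * r * r

def grad(w, example):
    x, y = example
    r = dot(w, x) - y
    return [r * xi for xi in x]

def train_with_tracin_ideal(train, z_test, w0, lr, steps, seed=0):
    """Run SGD and credit each step's drop in test loss to the example used."""
    rng = random.Random(seed)
    w = list(w0)
    ideal = [0.0] * len(train)
    for _ in range(steps):
        i = rng.randrange(len(train))          # training example used at this step
        before = loss(w, z_test)
        w = [wi - lr * gi for wi, gi in zip(w, grad(w, train[i]))]
        ideal[i] += before - loss(w, z_test)   # l(w_t, z') - l(w_{t+1}, z')
    return w, ideal

train = [([1.0, 2.0], 1.0), ([0.5, -1.0], 0.0), ([2.0, 0.5], 1.5)]
z_test = ([0.5, 1.5], 0.8)
w0 = [0.0, 0.0]
w_T, ideal = train_with_tracin_ideal(train, z_test, w0, lr=0.05, steps=200)
print("sum of TracInIdeal scores:", sum(ideal))
print("l(w_0, z') - l(w_T, z'):  ", loss(w0, z_test) - loss(w_T, z_test))
```

The two printed values agree up to floating-point rounding whatever the learning rate or sampling order. The identity is pure bookkeeping, not an approximation.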

Figures (10)

  • Figure 1: CIFAR-10 and MNIST mislabelled data identification with TracIn, Representer Points, and Influence Functions. We use “Fraction of mislabelled identified” on the y-axis to compare the effectiveness of each method. (RP = Random Projections, CP = CheckPoints)
  • Figure 2: Analysis of the effect of approximations: Pearson correlation of first-order approximate TracIn influences with heuristic influences, over multiple checkpoints and with projections of different sizes. RP stands for random projection.
  • Figure 3: Proponent and opponent examples using TracIn, representer point, and influence functions. (Predicted class in brackets)
  • Figure 4: TracIn applied on ImageNet. Each row starts with the test example, followed by three proponents and three opponents. The test image in the first row is classified as band-aid and is the only misclassified example. (af-chameleon: african-chameleon, fr-bulldog: french-bulldog)
  • Figure 5: CIFAR-10 results: proponent and opponent examples of a correctly classified cat for influence functions, representer point, and TracIn. (Predicted class in brackets)
  • ...and 5 more figures
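
Figure 1 plots, for each method, the fraction of mislabelled training examples found as a growing fraction of the training set is inspected. For TracIn the ranking uses self-influence, $\texttt{TracInCP}(z, z) = \sum_i \eta_i \lVert \nabla\ell(w_{t_i}, z) \rVert^2$, so examples the model keeps struggling to fit rise to the top. The sketch below reproduces that style of evaluation on a hypothetical ten-point logistic-regression dataset with two deliberately flipped labels. The data, learning rate, epoch count, and helper names are illustrative assumptions, and checkpoints are simply saved at the end of each epoch.

```python
import math
import random

def sigmoid(t):
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def grad(w, example):
    # Logistic loss for y in {0, 1}: gradient is (sigmoid(w.x) - y) * x.
    x, y = example
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    return [(p - y) * xi for xi in x]

def train_with_checkpoints(data, lr, epochs, seed=0):
    # Plain SGD (batch size 1); save (weights, lr) at the end of every epoch.
    rng = random.Random(seed)
    w = [0.0] * len(data[0][0])
    checkpoints = []
    order = list(range(len(data)))
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            w = [wi - lr * gi for wi, gi in zip(w, grad(w, data[i]))]
        checkpoints.append((list(w), lr))
    return checkpoints

def self_influence(checkpoints, z):
    # TracInCP(z, z): learning-rate-weighted squared gradient norms.
    return sum(lr * sum(g * g for g in grad(w, z)) for w, lr in checkpoints)

def fraction_identified(scores, is_mislabelled, fraction_checked):
    # Inspect the top-scoring fraction of the data; report recall of mislabels.
    k = round(fraction_checked * len(scores))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    found = sum(1 for i in ranked[:k] if is_mislabelled[i])
    return found / sum(is_mislabelled)

# Hypothetical 1-D data with a bias feature; label is 1 when the feature is positive.
features = [-3.0, -2.5, -2.0, -1.5, -1.0, 1.0, 1.5, 2.0, 2.5, 3.0]
labels = [1 if f > 0 else 0 for f in features]
flipped = {2, 7}                      # corrupt the labels at features -2.0 and 2.0
for i in flipped:
    labels[i] = 1 - labels[i]
data = [([f, 1.0], y) for f, y in zip(features, labels)]
is_mislabelled = [i in flipped for i in range(len(data))]

checkpoints = train_with_checkpoints(data, lr=0.1, epochs=20)
scores = [self_influence(checkpoints, z) for z in data]
for frac in (0.1, 0.2, 0.3):
    print(frac, fraction_identified(scores, is_mislabelled, frac))
```

The flipped examples sit on the wrong side of the learned boundary at every checkpoint, so their gradients stay large and they top the self-influence ranking. In this toy, both are found after inspecting 20% of the data.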

Theorems & Definitions (5)

  • Lemma 3.1
  • Remark 3.2: Proponents and Opponents
  • Remark 3.3
  • Remark 3.4: Handling Variations of Training
  • Remark 3.5: Counterfactual Interpretation
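
The terminology of Remark 3.2 can be read off the sign of the score. Training examples with positive influence, whose gradient steps reduced the test loss, are proponents. Those with negative influence are opponents. The minimal sketch below assumes a single hypothetical checkpoint at $w = 0$ with learning rate 0.1, a toy linear-regression model with squared loss, and made-up example names. It scores a few training points and returns the top-$k$ of each kind.

```python
# Hypothetical example: split TracIn scores into proponents and opponents.

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def grad(w, example):
    # Squared loss 0.5 * (w.x - y)^2 for a linear model.
    x, y = example
    r = dot(w, x) - y
    return [r * xi for xi in x]

def tracin_scores(train, z_test, checkpoints):
    # TracInCP score for every named training example.
    return {
        name: sum(lr * dot(grad(w, z), grad(w, z_test)) for w, lr in checkpoints)
        for name, z in train.items()
    }

def proponents_and_opponents(scores, k):
    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    proponents = [(n, s) for n, s in ranked if s > 0][:k]            # most helpful first
    opponents = [(n, s) for n, s in reversed(ranked) if s < 0][:k]   # most harmful first
    return proponents, opponents

train = {
    "a": ([1.0, 1.0], 1.0),    # same input and target as the test point
    "b": ([1.0, 0.5], 2.0),
    "c": ([1.0, 1.0], -1.0),   # same input, opposite target
    "d": ([0.0, 1.0], -0.5),
    "e": ([1.0, -1.0], 1.0),   # gradient orthogonal to the test gradient
}
z_test = ([1.0, 1.0], 1.0)
checkpoints = [([0.0, 0.0], 0.1)]
pro, opp = proponents_and_opponents(tracin_scores(train, z_test, checkpoints), k=2)
print("proponents:", pro)
print("opponents:", opp)
```

Example `e` scores exactly zero because its gradient is orthogonal to the test gradient. Under this first-order estimate it is therefore neither a proponent nor an opponent.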