Training Data Attribution via Approximate Unrolled Differentiation
Juhan Bae, Wu Lin, Jonathan Lorraine, Roger Grosse
TL;DR
This work addresses training data attribution (TDA) for modern neural networks. Implicit-differentiation methods such as influence functions (IF) assume that training has converged to a unique optimum, while unrolling-based methods are costly for large models or multi-stage training. Source is a segmented, stationary unrolling approach: it approximates the total derivative of the final parameters with respect to downweighting a training point, yielding an influence-function-like estimator that does not require storing all training checkpoints. The training trajectory is partitioned into segments, and the Hessians and gradients within each segment are modeled as stationary. Under this assumption, Source derives a closed-form expression that combines per-segment influences through matrix functions and a damped inverse-Hessian term; an EK-FAC parameterization keeps the Hessian computations scalable. Empirically, Source outperforms existing TDA techniques on diverse tasks, particularly when models are non-converged or trained in multiple stages. It offers a practical middle ground between IF and full unrolling, with favorable computational trade-offs, and can serve as a scalable tool for data provenance, debugging, and dataset curation in complex training pipelines and large-scale models.
Abstract
Many training data attribution (TDA) methods aim to estimate how a model's behavior would change if one or more data points were removed from the training set. Methods based on implicit differentiation, such as influence functions, can be made computationally efficient, but fail to account for underspecification, the implicit bias of the optimization algorithm, or multi-stage training pipelines. By contrast, methods based on unrolling address these issues but face scalability challenges. In this work, we connect the implicit-differentiation-based and unrolling-based approaches and combine their benefits by introducing Source, an approximate unrolling-based TDA method that is computed using an influence-function-like formula. While being computationally efficient compared to unrolling-based approaches, Source is suitable in cases where implicit-differentiation-based approaches struggle, such as in non-converged models and multi-stage training pipelines. Empirically, Source outperforms existing TDA techniques in counterfactual prediction, especially in settings where implicit-differentiation-based approaches fall short.
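The connection between unrolling and influence functions that the abstract describes can be illustrated on a toy problem. The sketch below is not the paper's estimator; it is a minimal pure-Python example that assumes a single training segment, a diagonal Hessian `h`, and a constant per-example gradient `g` (a stand-in for the stationarity assumption). All function names and numbers are hypothetical. It differentiates through `K` gradient-descent steps with respect to a downweighting parameter, checks the result against the geometric-series closed form, and compares both to the converged influence-function limit.

```python
def unrolled_sensitivity(h, g, lr, num_steps):
    """Differentiate through num_steps of gradient descent on a diagonal quadratic.

    h: Hessian eigenvalues (assumed constant along the trajectory)
    g: gradient of the downweighted example (assumed constant)
    Returns d(theta_K)/d(eps) for the update
        theta <- theta - lr * (h * theta - b + eps * g).
    """
    d = [0.0] * len(h)  # d(theta_0)/d(eps) = 0: the initialization ignores eps
    for _ in range(num_steps):
        # Chain rule through one step: d <- (1 - lr*h) * d - lr * g
        d = [(1.0 - lr * hi) * di - lr * gi for hi, gi, di in zip(h, g, d)]
    return d


def closed_form_sensitivity(h, g, lr, num_steps):
    """Geometric-series solution of the same recursion: -(1 - (1 - lr*h)^K) / h * g."""
    return [-(1.0 - (1.0 - lr * hi) ** num_steps) / hi * gi for hi, gi in zip(h, g)]


def influence_function_sensitivity(h, g):
    """Converged limit (K -> infinity): -g / h, i.e. -H^{-1} g."""
    return [-gi / hi for hi, gi in zip(h, g)]


# Hypothetical toy values: two well-conditioned directions and one nearly flat one.
h = [2.0, 0.5, 0.01]
g = [1.0, 1.0, 1.0]
lr, K = 0.1, 50

unrolled = unrolled_sensitivity(h, g, lr, K)
closed = closed_form_sensitivity(h, g, lr, K)
converged = influence_function_sensitivity(h, g)

for name, values in [("unrolled", unrolled), ("closed form", closed), ("IF limit", converged)]:
    print(name, [round(v, 4) for v in values])
```

In this toy setting the unrolled derivative and the closed form agree, and along the well-conditioned direction both match the influence-function limit. Along the nearly flat direction the finite-step value is about -4.9 while the converged formula gives -100, which illustrates why an estimator that assumes convergence can be misleading for a non-converged model.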
