Unlearning Traces the Influential Training Data of Language Models
Masaru Isonuma, Ivan Titov
TL;DR
This work tackles the problem of attributing LLM outputs to their training data without costly retraining. It introduces UnTrac, which unlearns each training dataset as a whole via gradient ascent and measures the resulting change in test loss, and UnTrac-Inv, which instead unlearns the test dataset and measures the change in loss on each training dataset, efficiently approximating UnTrac's estimates. The authors show that these methods correlate more strongly with ground-truth influence than existing Hessian-based, gradient-based, and sampling-based influence functions, in both finetuning and pretraining scenarios, while scaling well and requiring only modest memory. The approach offers a practical tool for tracing the sources of harmful, biased, or false content in large language models, along with guidance on choosing hyperparameters and optimizers in real-world settings.
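The dataset-level unlearning loop described above can be illustrated with a minimal sketch. It assumes a toy one-feature linear regression in place of a language model, plain gradient steps rather than the optimizers studied in the paper, and hypothetical names and hyperparameters (`train`, `unlearn`, `untrac`, `untrac_inv`, `lr`, `steps`); it is not the authors' implementation. In both functions a positive score means that unlearning raised the relevant loss, marking a training dataset that helps the model on the test set.

```python
from typing import List, Tuple

# Toy illustration only: a linear model stands in for a language model,
# and every name and hyperparameter here is hypothetical.
Dataset = List[Tuple[float, float]]  # (x, y) pairs
Params = List[float]                 # [slope, intercept]


def loss(w: Params, data: Dataset) -> float:
    """Mean squared error of the toy model y_hat = w[0] * x + w[1]."""
    return sum((w[0] * x + w[1] - y) ** 2 for x, y in data) / len(data)


def grad(w: Params, data: Dataset) -> Params:
    """Gradient of `loss` with respect to [slope, intercept]."""
    n = len(data)
    g_slope = sum(2 * (w[0] * x + w[1] - y) * x for x, y in data) / n
    g_bias = sum(2 * (w[0] * x + w[1] - y) for x, y in data) / n
    return [g_slope, g_bias]


def train(datasets: List[Dataset], lr: float = 0.05, steps: int = 500) -> Params:
    """Plain gradient descent on the union of all training datasets."""
    pooled = [pair for d in datasets for pair in d]
    w = [0.0, 0.0]
    for _ in range(steps):
        w = [wi - lr * gi for wi, gi in zip(w, grad(w, pooled))]
    return w


def unlearn(w: Params, data: Dataset, lr: float, steps: int) -> Params:
    """Gradient *ascent* on `data`, starting from a copy of the trained params."""
    u = list(w)
    for _ in range(steps):
        u = [ui + lr * gi for ui, gi in zip(u, grad(u, data))]
    return u


def untrac(w: Params, train_sets: List[Dataset], test_set: Dataset,
           lr: float = 0.05, steps: int = 10) -> List[float]:
    """Unlearn each training dataset separately; score = rise in test loss."""
    base = loss(w, test_set)
    return [loss(unlearn(w, d, lr, steps), test_set) - base for d in train_sets]


def untrac_inv(w: Params, train_sets: List[Dataset], test_set: Dataset,
               lr: float = 0.01, steps: int = 10) -> List[float]:
    """Unlearn the test set once; score = rise in each training dataset's loss."""
    u = unlearn(w, test_set, lr, steps)
    return [loss(u, d) - loss(w, d) for d in train_sets]


# Dataset A follows y = 2x, dataset B follows y = -x; the test set follows y = 2x.
A = [(-1.0, -2.0), (0.0, 0.0), (1.0, 2.0)]
B = [(-1.0, 1.0), (0.0, 0.0), (1.0, -1.0)]
test = [(2.0, 4.0), (-2.0, -4.0)]

w = train([A, B])
print("UnTrac    :", untrac(w, [A, B], test))
print("UnTrac-Inv:", untrac_inv(w, [A, B], test))
```

Dataset A shares the test set's trend, so both functions should give it a positive score and give B a negative one. `untrac` runs one gradient-ascent pass per training dataset, whereas `untrac_inv` runs a single pass on the test set and then needs only loss evaluations on each training dataset, which is why the inverse variant scales better when there are many training datasets.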
Abstract
Identifying the training datasets that influence a language model's outputs is essential for minimizing the generation of harmful content and enhancing its performance. Ideally, one could measure the influence of each dataset by removing it from training and retraining the model; however, retraining a model many times is prohibitively expensive. This paper presents UnTrac: unlearning traces the influence of a training dataset on the model's performance. UnTrac is extremely simple: each training dataset is unlearned by gradient ascent, and we evaluate how much the model's predictions change after unlearning. Furthermore, we propose a more scalable approach, UnTrac-Inv, which unlearns a test dataset and evaluates the unlearned model on the training datasets. UnTrac-Inv behaves similarly to UnTrac while remaining efficient for massive training datasets. In our experiments, we examine whether our methods can assess the influence of pretraining datasets on the generation of toxic, biased, and untruthful content. Our methods estimate this influence much more accurately than existing methods while requiring neither excessive memory nor multiple checkpoints.
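The removal-and-retraining measure that the abstract calls ideal can be sketched on a toy problem as well. This minimal sketch assumes a one-feature linear regression trained by plain gradient descent as a stand-in for a language model; `retrain_influence` and the other names are hypothetical. Each dataset's score is the test loss after retraining without that dataset minus the test loss of the model trained on everything, so the cost is one full retraining per dataset.

```python
from typing import List, Tuple

# Toy stand-in for a language model; all names here are hypothetical.
Dataset = List[Tuple[float, float]]  # (x, y) pairs
Params = List[float]                 # [slope, intercept]


def loss(w: Params, data: Dataset) -> float:
    """Mean squared error of the toy model y_hat = w[0] * x + w[1]."""
    return sum((w[0] * x + w[1] - y) ** 2 for x, y in data) / len(data)


def grad(w: Params, data: Dataset) -> Params:
    """Gradient of `loss` with respect to [slope, intercept]."""
    n = len(data)
    g_slope = sum(2 * (w[0] * x + w[1] - y) * x for x, y in data) / n
    g_bias = sum(2 * (w[0] * x + w[1] - y) for x, y in data) / n
    return [g_slope, g_bias]


def train(datasets: List[Dataset], lr: float = 0.05, steps: int = 500) -> Params:
    """Plain gradient descent from scratch on the union of the given datasets."""
    pooled = [pair for d in datasets for pair in d]
    w = [0.0, 0.0]
    for _ in range(steps):
        w = [wi - lr * gi for wi, gi in zip(w, grad(w, pooled))]
    return w


def retrain_influence(train_sets: List[Dataset], test_set: Dataset) -> List[float]:
    """Influence of dataset k: test loss after retraining without k, minus the
    test loss of the model trained on all datasets (one retraining per dataset)."""
    base = loss(train(train_sets), test_set)
    scores = []
    for k in range(len(train_sets)):
        rest = train_sets[:k] + train_sets[k + 1:]  # leave dataset k out
        scores.append(loss(train(rest), test_set) - base)
    return scores


# Dataset A follows y = 2x, dataset B follows y = -x; the test set follows y = 2x.
A = [(-1.0, -2.0), (0.0, 0.0), (1.0, 2.0)]
B = [(-1.0, 1.0), (0.0, 0.0), (1.0, -1.0)]
test = [(2.0, 4.0), (-2.0, -4.0)]

print(retrain_influence([A, B], test))
```

This should print values close to [27.0, -9.0]: dropping A, which shares the test set's trend, raises the test loss, while dropping B lowers it. For a language model each of those retraining runs is a full training job, which is the expense that unlearning-based estimates are meant to avoid.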
