
Influence Functions for Scalable Data Attribution in Diffusion Models

Bruno Mlodozeniec, Runa Eschenhagen, Juhan Bae, Alexander Immer, David Krueger, Richard Turner

TL;DR

This work extends influence-function theory to diffusion models, addressing data attribution and interpretability by formulating Hessian-based attributions with scalable approximations. It advocates using an (E)K-FAC approximation of the generalized Gauss-Newton (GGN) matrix for the diffusion objective, recasting prior diffusion-attribution methods as design choices within this framework. Empirical results show improved data-attribution performance (e.g., on the Linear Data-modelling Score, LDS) over prior approaches and a reduced need for hyperparameter tuning, though challenges remain in accurately predicting post-retraining measurement changes and in proxying marginal probabilities. The findings advance data-centric safety and copyright-attribution discussions for diffusion models and point to future work on better marginal-probability proxies and enhanced Hessian-approximation strategies.
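To make the (E)K-FAC idea concrete, here is a minimal NumPy sketch for a single fully-connected layer, under stated assumptions: per-example layer inputs `acts` and per-example gradients with respect to the layer outputs `grads` are already available, and the GGN block of the weight matrix is approximated by the Kronecker product $A \otimes G$ of their second-moment matrices. EK-FAC keeps the Kronecker eigenbasis but refits the eigenvalues to the per-example gradients projected into that basis. All function names are hypothetical, and the sketch does not cover the diffusion-specific choices the paper studies (such as the "expand"/"reduce" variants or how `grads` are obtained).

```python
import numpy as np


def kfac_factors(acts, grads):
    """Kronecker factors for one linear layer (hypothetical helper).

    acts:  (N, d_in)  per-example layer inputs a_n
    grads: (N, d_out) per-example gradients g_n w.r.t. the layer outputs
    """
    n = acts.shape[0]
    A = acts.T @ acts / n    # E[a a^T], shape (d_in, d_in)
    G = grads.T @ grads / n  # E[g g^T], shape (d_out, d_out)
    return A, G


def eigenbasis(A, G):
    # Eigendecompose each symmetric PSD Kronecker factor separately
    sA, QA = np.linalg.eigh(A)
    sG, QG = np.linalg.eigh(G)
    return sA, QA, sG, QG


def ekfac_eigenvalues(acts, grads, QA, QG):
    # The per-example weight gradient is g_n a_n^T; in the Kronecker eigenbasis
    # its (i, j) entry is (QG^T g_n)_i * (QA^T a_n)_j. EK-FAC uses the mean square.
    proj_g = grads @ QG  # row n holds QG^T g_n
    proj_a = acts @ QA   # row n holds QA^T a_n
    return np.einsum("ni,nj->ij", proj_g**2, proj_a**2) / acts.shape[0]


def damped_ihvp(V, QA, QG, eigvals, damping):
    """Approximate (Q diag(eigvals) Q^T + damping * I)^{-1} vec(V), Q = kron(QA, QG).

    V has the shape of the weight matrix, (d_out, d_in); vec is column-major.
    """
    V_rot = QG.T @ V @ QA                # rotate into the eigenbasis
    V_rot = V_rot / (eigvals + damping)  # invert the damped diagonal
    return QG @ V_rot @ QA.T             # rotate back
```

Passing `np.outer(sG, sA)` as `eigvals` gives damped K-FAC (these are exactly the eigenvalues of $A \otimes G$), while `ekfac_eigenvalues(...)` gives the eigenvalue-corrected variant. Working in the eigenbasis avoids ever forming the $d_\text{in} d_\text{out} \times d_\text{in} d_\text{out}$ matrix.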

Abstract

Diffusion models have led to significant advancements in generative modelling. Yet their widespread adoption poses challenges regarding data attribution and interpretability. In this paper, we aim to help address such challenges in diffusion models by developing an influence-function framework. Influence-function-based data attribution methods approximate how a model's output would have changed if some training data were removed. In supervised learning, this is usually used to predict how the loss on a particular example would change. For diffusion models, we focus on predicting the change in the probability of generating a particular example via several proxy measurements. We show how to formulate influence functions for such quantities and how previously proposed methods can be interpreted as particular design choices in our framework. To ensure scalability of the Hessian computations in influence functions, we systematically develop K-FAC approximations based on generalised Gauss-Newton matrices specifically tailored to diffusion models. We show that our recommended method outperforms previous data attribution approaches on common evaluations, such as the Linear Data-modelling Score (LDS) or retraining without top influences, without the need for method-specific hyperparameter tuning.
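The core influence-function computation can be checked end to end on a toy problem where retraining is cheap. The sketch below uses L2-regularised linear regression in NumPy, not a diffusion model, so it is only an illustrative assumption: the measurement $m(\theta)$ is the squared error on one test point, the predicted effect of removing training point $j$ is $\frac{1}{N}\nabla_\theta m(\theta^\star)^\top H^{-1} \nabla_\theta \ell(x_j, \theta^\star)$ with $H$ the exact Hessian of the training objective, and the prediction is compared against actually refitting without that point. The function names are hypothetical.

```python
import numpy as np


def fit_ridge(X, y, lam, n_total):
    # Minimises (1/n_total) * sum_n 0.5 * (y_n - x_n^T theta)^2 + 0.5 * lam * ||theta||^2
    d = X.shape[1]
    H = X.T @ X / n_total + lam * np.eye(d)
    return np.linalg.solve(H, X.T @ y / n_total)


def influence_on_test_loss(X, y, x_test, y_test, lam):
    """Predicted change in the test loss when each training point is removed."""
    n, d = X.shape
    theta = fit_ridge(X, y, lam, n)
    H = X.T @ X / n + lam * np.eye(d)               # Hessian of the training objective
    residuals = X @ theta - y
    train_grads = residuals[:, None] * X            # row j: grad of 0.5*(y_j - x_j^T theta)^2
    test_grad = (x_test @ theta - y_test) * x_test  # grad of m(theta) = 0.5*(x^T theta - y)^2
    ihvp = np.linalg.solve(H, test_grad)            # H^{-1} grad m, one solve for all j
    return train_grads @ ihvp / n                   # epsilon = 1/N removes point j


def exact_loo_change(X, y, x_test, y_test, lam):
    """Ground truth: refit without each point (keeping the 1/N weighting fixed)."""
    n = X.shape[0]

    def m(th):
        return 0.5 * (x_test @ th - y_test) ** 2

    base = m(fit_ridge(X, y, lam, n))
    mask = np.ones(n, dtype=bool)
    changes = []
    for j in range(n):
        mask[j] = False
        changes.append(m(fit_ridge(X[mask], y[mask], lam, n)) - base)
        mask[j] = True
    return np.array(changes)
```

Only one Hessian solve is needed per measurement, so the predicted changes for all $N$ training points cost one linear solve plus $N$ dot products, whereas exact retraining needs $N$ refits. For large networks, that Hessian solve is the step the K-FAC approximation is meant to make tractable.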

Paper Structure

This paper contains 46 sections, 2 theorems, 45 equations, 22 figures, 5 tables, 2 algorithms.

Key Result

Theorem 1

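Let $F : \mathbb{R}^{n}\times\mathbb{R}^m \to \mathbb{R}^m$ be a continuously differentiable function, and let $\mathbb{R}^{n}\times \mathbb{R}^m$ have coordinates $(\mathbf{x}, \mathbf{y})$. Fix a point $(\mathbf{a}, \mathbf{b}) = (a_1, \ldots, a_n, b_1, \ldots, b_m)$ with $F(\mathbf{a}, \mathbf{b}) = \mathbf{0}$. If the Jacobian matrix $\frac{\partial F}{\partial \mathbf{y}}(\mathbf{a}, \mathbf{b}) = \left[\frac{\partial F_i}{\partial y_j}(\mathbf{a}, \mathbf{b})\right]_{i,j=1}^{m}$ is invertible, then there exists an open set $U \subset \mathbb{R}^n$ containing $\mathbf{a}$ such that there exists a unique continuously differentiable function $g : U \to \mathbb{R}^m$ with $g(\mathbf{a}) = \mathbf{b}$ and $F(\mathbf{x}, g(\mathbf{x})) = \mathbf{0}$ for all $\mathbf{x} \in U$.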
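The derivative of the implicit function $g$ (the topic of Remark 1) follows by differentiating $F(\mathbf{x}, g(\mathbf{x})) = \mathbf{0}$ with the chain rule. Applying it with $\mathbf{x} = \varepsilon$, $\mathbf{y} = \theta$ and $F(\varepsilon, \theta) = \nabla_\theta \mathcal{L}(\varepsilon, \theta)$ for the extended loss $\mathcal{L}(\varepsilon, \theta) = \frac{1}{N}\sum_{n=1}^N \ell(x_n, \theta) - \varepsilon \ell(x_j, \theta)$ of Figure 4 gives the standard influence-function approximation. The sketch below assumes the Hessian $H$ at the optimum $\theta^\star = \theta^\star(0)$ is invertible and writes $m(\theta)$ for a generic measurement function; it is the textbook derivation, not necessarily the paper's exact presentation.

```latex
% Chain rule applied to F(x, g(x)) = 0 for all x in U:
\frac{\partial F}{\partial \mathbf{x}}(\mathbf{x}, g(\mathbf{x}))
  + \frac{\partial F}{\partial \mathbf{y}}(\mathbf{x}, g(\mathbf{x}))\,\frac{\partial g}{\partial \mathbf{x}}(\mathbf{x}) = \mathbf{0}
\quad\Longrightarrow\quad
\frac{\partial g}{\partial \mathbf{x}}(\mathbf{x})
  = -\left[\frac{\partial F}{\partial \mathbf{y}}(\mathbf{x}, g(\mathbf{x}))\right]^{-1}
     \frac{\partial F}{\partial \mathbf{x}}(\mathbf{x}, g(\mathbf{x})).

% Influence functions: F(\varepsilon, \theta) = \nabla_\theta \mathcal{L}(\varepsilon, \theta)
\frac{\partial F}{\partial \theta}\Big|_{\varepsilon = 0}
  = \nabla^2_\theta \frac{1}{N}\sum_{n=1}^N \ell(x_n, \theta^\star) \;=:\; H,
\qquad
\frac{\partial F}{\partial \varepsilon} = -\nabla_\theta \ell(x_j, \theta^\star),

\frac{d\theta^\star}{d\varepsilon}\Big|_{\varepsilon = 0} = H^{-1}\nabla_\theta \ell(x_j, \theta^\star),
\qquad
\theta^\star\!\left(\tfrac{1}{N}\right) - \theta^\star(0) \approx \frac{1}{N} H^{-1}\nabla_\theta \ell(x_j, \theta^\star),

m\!\left(\theta^\star\!\left(\tfrac{1}{N}\right)\right) - m\!\left(\theta^\star(0)\right)
  \approx \frac{1}{N}\,\nabla_\theta m(\theta^\star)^\top H^{-1}\nabla_\theta \ell(x_j, \theta^\star).
```

Setting $\varepsilon = 1/N$ cancels the $j$-th term of the average, which is why the linear extrapolation in Figure 4 is evaluated there.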

Figures (22)

  • Figure 1: Most influential training data points as identified by K-FAC Influence Functions for samples generated by a denoising diffusion probabilistic model trained on ArtBench. The top influences are those whose omission from the training set is predicted to most increase the loss of the generated sample. Negative influences are those predicted to most decrease the loss, and the most neutral are those that should change the loss the least.
  • Figure 2: Linear Data-modelling Score (LDS) for different data attribution methods. Methods that substitute incorrect measurement functions into the approximation are plotted separately. Where applicable, we plot results for both the best Hessian-approximation damping value and a “default” damping value. The numerical results are reported in black for the best damping value and in gray for the “default” damping value. “(m. loss)” indicates that the appropriate measurement function was replaced with the loss measurement function $\ell(\theta, x)$ in the approximation. Results for the exact retraining method (oracle) are also shown. The standard error of the LDS estimate is indicated with ‘$\pm$’, where the mean is taken over the different generated samples $x$ for which the change in measurement is being estimated.
  • Figure 3: Changes in measurements under counterfactual retraining without top influences for the loss measurement. The standard error in the estimate of the mean is indicated with error bars and reported after ‘$\pm$’, where the average is over different generated samples for which top influences are being identified.
  • Figure 4: Illustration of the influence function approximation for a 1-dimensional parameter space $\theta\in\mathbb{R}$. Influence functions consider the extended loss landscape $\mathcal{L}(\varepsilon, \theta )\mathrel{\stackrel{\textnormal{\tiny def}}{=}} \frac{1}{N}\sum_{n=1}^N \ell(x_n, \theta) - \varepsilon \ell(x_j, \theta)$, where the loss $\ell(x_j, \theta)$ for some datapoint $x_j$ (or, alternatively, a group of datapoints) is down-weighted by $\varepsilon$. By linearly extrapolating how the optimal parameters $\theta$ change around $\varepsilon=0$, we can predict how the optimal parameters would change when the term $\ell(x_j, \theta)$ is fully removed from the loss, i.e. at $\varepsilon = 1/N$.
  • Figure 5: Ablation over the different Hessian approximations introduced in Sections \ref{sec:approx-hessian} and \ref{sec:kfac-diffusion-appendix}. We ablate two versions of the GGN, the “MC” Fisher $\mathrm{GGN}^\texttt{model}$ in Eq. \ref{eq:GGN-diffusion-modelout} and the “Empirical” Fisher $\mathrm{GGN}^\texttt{loss}$ in Eq. \ref{eq:empirical-fisher}, as well as multiple settings for the K-FAC approximation: “expand” versus “reduce”, and whether we use the eigenvalue-corrected variant (EK-FAC) or not (K-FAC). As in Figure \ref{fig:lds-all-main-figures}, we report results for both the best damping value and a default damping value of $10^{-8}$. The damping-value ablation from which these results are selected is reported in Figure \ref{fig:damping-ablation-cifar2-kfac-ablation}.
  • ...and 17 more figures
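Figure 2 evaluates attributions with the Linear Data-modelling Score. A common definition, assumed here because this section does not spell out the protocol, is: retrain on $M$ random subsets of the training set, record the measurement for a fixed generated sample $x$ after each retraining, predict each measurement as the sum of the attribution scores of the training points included in that subset, and report the Spearman rank correlation between predicted and actual values. The minimal NumPy sketch below computes that correlation for one sample $x$; the subset masks, scores, and function names are hypothetical inputs.

```python
import numpy as np


def _ranks(v):
    # Ranks 0..n-1; ties are not averaged (adequate for continuous values)
    r = np.empty(len(v))
    r[np.argsort(v)] = np.arange(len(v))
    return r


def spearman(a, b):
    # Spearman correlation = Pearson correlation of the ranks
    return np.corrcoef(_ranks(a), _ranks(b))[0, 1]


def lds(scores, masks, measurements):
    """LDS for one target sample (hypothetical helper).

    scores:       (N,) additive attribution score per training point
    masks:        (M, N) boolean, masks[m, i] is True if point i is in subset m
    measurements: (M,) measurement after retraining on each subset
    """
    predicted = masks.astype(float) @ scores  # sum of scores over included points
    return spearman(predicted, measurements)
```

Removal-based influence scores, which predict the change caused by omitting a point, enter with a flipped sign: summing $-\Delta_i$ over the included points ranks subsets the same way as summing $\Delta_i$ over the excluded ones. The per-sample values would then be averaged over generated samples $x$, with the standard error reported as in Figure 2.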

Theorems & Definitions (4)

  • Theorem 1: Implicit Function Theorem (Krantz and Parks, 2003)
  • Remark 1: Derivative of the implicit function
  • Lemma 1
  • Remark 2