Better Hessians Matter: Studying the Impact of Curvature Approximations in Influence Functions

Steve Hong, Runa Eschenhagen, Bruno Mlodozeniec, Richard Turner

TL;DR

The experiments show that better Hessian approximations consistently yield better influence score quality, offering justification for recent research efforts towards that end and highlighting which approximations are most critical, guiding future efforts to balance computational tractability and attribution accuracy.

Abstract

Influence functions offer a principled way to trace model predictions back to training data, but their use in deep learning is hampered by the need to invert a large, ill-conditioned Hessian matrix. Approximations such as Generalised Gauss-Newton (GGN) and Kronecker-Factored Approximate Curvature (K-FAC) have been proposed to make influence computation tractable, yet it remains unclear how the departure from exactness impacts data attribution performance. Critically, given the restricted regime in which influence functions are derived, it is not necessarily clear that better Hessian approximations should even lead to better data attribution performance. In this paper, we investigate the effect of Hessian approximation quality on influence-function attributions in a controlled classification setting. Our experiments show that better Hessian approximations consistently yield better influence score quality, offering justification for recent research efforts towards that end. We further decompose the approximation steps for recent Hessian approximation methods and evaluate each step's influence on attribution accuracy. Notably, the mismatch between K-FAC eigenvalues and GGN/EK-FAC eigenvalues accounts for the majority of the error and influence loss. These findings highlight which approximations are most critical, guiding future efforts to balance computational tractability and attribution accuracy.
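
For reference, the Hessian inversion mentioned in the abstract enters through the standard influence-function expression. The notation below ($z_i$ for a training example, $z_q$ for a query, $m$ for the query metric, $\ell$ for the training loss, $\theta^\star$ for the trained parameters, $\lambda$ for a damping constant) is assumed here for illustration and may differ from the paper's own definitions. The damping term is a common practical choice for ill-conditioned Hessians, not something the abstract states.

```latex
\[
\mathcal{I}(z_i, z_q)
  \;=\; -\,\nabla_\theta m(z_q, \theta^\star)^{\top}
        \left(\mathbf{H} + \lambda \mathbf{I}\right)^{-1}
        \nabla_\theta \ell(z_i, \theta^\star),
\qquad
\mathbf{H} \;=\; \frac{1}{N} \sum_{n=1}^{N} \nabla_\theta^{2}\, \ell(z_n, \theta^\star)
\]
```

Approximations such as GGN or K-FAC would replace $\mathbf{H}$ in this expression with a matrix that is cheaper to build and invert.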

Paper Structure

This paper contains 42 sections, 20 equations, 8 figures, 1 table.

Figures (8)

  • Figure 1: Left: Attribution quality vs. Hessian approximation error - Training duration. LDS and approximation error (Equation \ref{eq: approx error}) for epochs $\{10, 100, 1{,}000\}$. Setting is fixed at depth $=$ 8 and width $=$ 16; other hyperparameters follow Table \ref{tab:hyperparameters_digits}. Right: Error decomposition table: incremental shares along the curvature-approximation path. $\Delta$ES$\%$ denotes Error Share in percentage in the Hessian$\to$K-FAC path and $\Delta$LDS$\%$ denotes the total Hessian$\to$K-FAC LDS percentage change across steps. B-GGN denotes Block-Diagonal GGN.
  • Figure 2: Left: Attribution quality vs. Hessian approximation error - Network depth. LDS and approximation error (Equation \ref{eq: approx error}) for depths $\{1, 4, 8\}$. Setting is fixed at epoch $=$ 100 and width $=$ 16; other hyperparameters follow Table \ref{tab:hyperparameters_digits}. Right: Error decomposition table: incremental shares along the curvature-approximation path. $\Delta$ES$\%$ denotes Error Share in percentage in the Hessian$\to$K-FAC path and $\Delta$LDS$\%$ denotes the total Hessian$\to$K-FAC LDS percentage change across steps. B-GGN denotes Block-Diagonal GGN.
  • Figure 3: Left: Attribution quality vs. Hessian approximation error - Network width. LDS and approximation error (Equation \ref{eq: approx error}) for widths $\{32, 64, 128\}$. Setting is fixed at epoch $=$ 100 and depth $=$ 1; other hyperparameters follow Table \ref{tab:hyperparameters_digits}. Right: Error decomposition table: incremental shares along the curvature-approximation path. $\Delta$ES$\%$ denotes Error Share in percentage in the Hessian$\to$K-FAC path and $\Delta$LDS$\%$ denotes the total Hessian$\to$K-FAC LDS percentage change across steps. B-GGN denotes Block-Diagonal GGN.
  • Figure 4: Expected leave-some-out LDS evaluation. We sample $K$ random subsets of the training data, retrain the model $R$ times per subset to average out randomness, and measure the resulting change in the query metric relative to the full-data baseline. We predict each group’s effect by summing per-example attributions, then report the Spearman rank correlation between observed and predicted effects across groups, aggregated over queries with 95% bootstrap confidence intervals.
  • Figure 5: Residual Term Magnitude (Digits). Fractional size of the residual $\mathbf{R}$ relative to the Hessian $\mathbf{H}$ across (left) training epochs, (right) network depth, and (bottom) network width. Lower values mean that $\mathbf{G}$ accounts for a larger share of $\mathbf{H}$.
  • ...and 3 more figures
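
The LDS evaluation described in the Figure 4 caption can be sketched for a single query as follows. This is a minimal illustration, not the authors' code: the function names (`average_ranks`, `spearman`, `lds_for_query`) and the input layout are hypothetical, the observed effects are assumed to be already averaged over the $R$ retrainings, and the aggregation over queries with bootstrap confidence intervals is omitted.

```python
from statistics import mean


def average_ranks(values):
    """Return 1-based ranks; tied values share their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of entries tied with position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        shared_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = shared_rank
        i = j + 1
    return ranks


def spearman(xs, ys):
    """Spearman rank correlation: Pearson correlation of the ranks.

    Raises ZeroDivisionError if either input is constant.
    """
    rx, ry = average_ranks(xs), average_ranks(ys)
    mx, my = mean(rx), mean(ry)
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    var_x = sum((a - mx) ** 2 for a in rx)
    var_y = sum((b - my) ** 2 for b in ry)
    return cov / (var_x * var_y) ** 0.5


def lds_for_query(attributions, groups, observed_effects):
    """LDS for one query.

    attributions:     per-training-example scores (hypothetical input layout)
    groups:           K lists of training-example indices
    observed_effects: K measured changes in the query metric relative to the
                      full-data baseline (assumed averaged over retrainings)
    """
    # Predict each group's effect by summing its per-example attributions.
    predicted = [sum(attributions[i] for i in group) for group in groups]
    return spearman(predicted, observed_effects)


if __name__ == "__main__":
    scores = [0.5, -0.2, 0.1, 0.3]
    groups = [[0, 1], [2, 3], [0, 3], [1, 2]]
    observed = [0.25, 0.5, 0.9, -0.3]
    print(lds_for_query(scores, groups, observed))
```

Because only the rank order matters, the sketch is insensitive to the overall scale of the attribution scores. The sign convention (whether a score predicts the effect of including or removing an example) has to match how the observed effects are measured.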