Faithful and Efficient Explanations for Neural Networks via Neural Tangent Kernel Surrogate Models

Andrew Engel, Zhichao Wang, Natalie S. Frank, Ioana Dumitriu, Sutanay Choudhury, Anand Sarwate, Tony Chiang

TL;DR

This work defines new approximate empirical neural tangent kernels (eNTK), analyzes how well the resulting kernel machine surrogate models correlate with the underlying neural network, and concludes that kernel machines using an approximate neural tangent kernel as the kernel function are effective surrogate models.

Abstract

A recent trend in explainable AI research has focused on surrogate modeling, where neural networks are approximated by simpler ML algorithms such as kernel machines. A second trend has been to utilize kernel functions in various explain-by-example or data attribution tasks. In this work, we combine these two trends to analyze approximate empirical neural tangent kernels (eNTK) for data attribution. Approximation is critical for eNTK analysis due to the high computational cost of computing the eNTK. We define new approximate eNTKs and perform a novel analysis of how well the resulting kernel machine surrogate models correlate with the underlying neural network. We introduce two new random projection variants of the approximate eNTK that allow users to tune the time and memory complexity of their calculation. We conclude that kernel machines using an approximate neural tangent kernel as the kernel function are effective surrogate models, with the introduced trace NTK being the most consistent performer. Open source software allowing users to efficiently calculate kernel functions in the PyTorch framework is available (https://github.com/pnnl/projection_ntk).
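As a rough illustration of the idea, the sketch below builds a cosine-similarity kernel from per-example gradient vectors and compares it with the same kernel computed after a Gaussian random projection. Everything in it is an assumption made for illustration: the synthetic array `jacobians` stands in for real network Jacobians, averaging over output neurons follows the Figure 3 caption read literally, and the helper names `cosine_kernel` and `random_projection` are hypothetical. It is not the API of the `projection_ntk` package, and the paper's exact definitions of the approximate eNTKs may differ.

```python
import numpy as np


def cosine_kernel(G):
    """Cosine similarity between every pair of rows of G, shape (n, d)."""
    norms = np.linalg.norm(G, axis=1, keepdims=True)
    U = G / norms
    return U @ U.T


def random_projection(G, k, rng):
    """Project the rows of G from d to k dimensions with a Gaussian matrix."""
    d = G.shape[1]
    R = rng.standard_normal((d, k)) / np.sqrt(k)
    return G @ R


rng = np.random.default_rng(0)
n, n_out, n_params, k = 6, 3, 2000, 1024

# Hypothetical stand-in for per-example Jacobians dF/dtheta,
# shape (examples, output neurons, parameters). A shared component
# makes the examples correlated with one another.
shared = rng.standard_normal((1, n_out, n_params))
jacobians = shared + rng.standard_normal((n, n_out, n_params))

# Assumption: average over output neurons to get one vector per example.
g = jacobians.mean(axis=1)

K_exact = cosine_kernel(g)                           # full-dimension kernel
K_proj = cosine_kernel(random_projection(g, k, rng)) # projected kernel
residual = np.abs(K_exact - K_proj).max()

print("largest absolute residual:", residual)
```

In this sketch the projected dimension `k` is the tuning knob: a smaller `k` lowers the memory and compute needed for the projected vectors at the price of a larger residual between the two kernels.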

Paper Structure (36 sections, 38 equations, 55 figures, 8 tables)


Figures (55)

  • Figure 1: Linear Realization of Bert-base Model. Each panel shows a linearization of a Bert-base transfer model, initialized from a different seed. An invertible mapping, described in Appendix \ref{Appendix:linearization}, is fit between the kGLM and the NN to transform the kGLM's final activations into the NN's. Both $\tau_{K}$ and the coefficient of determination ($R^2$) are shown for each model.
  • Figure 2: Overview of Using Kernel Functions for Data Attribution. A) An image from the CIFAR10 test dataset is chosen. B) We propagate the test image through the NN and plot, for each output neuron, the mean attribution of the training points from each class. C) Zooming into the neuron representing class "dog", we view the distribution of attributions as a modified box plot in which the central lines mark the mean and outliers are shown as flier points. The mean lines always fall within the interquartile range, suggesting that no small set of data points dominates the central value and therefore none dominates the data attribution.
  • Figure 3: Geometric intuition behind the $\mathrm{trNTK}$. A NN function is evaluated at two points, creating surfaces $F({\boldsymbol x}_i\,;{\boldsymbol \theta})$ and $F({\boldsymbol x}_j\,;{\boldsymbol \theta})$. Each surface is shown with a tangent hyperplane at the same point (${\boldsymbol \theta}$) in parameter space, which coincides with the end of training. The Jacobian vector defines the tangent hyperplane's orientation in parameter space. The $\mathrm{trNTK}$ is a kernel whose ($i,j$)-th element is the cosine of the angle between averaged Jacobian vectors. The more similar the geometry of ${\boldsymbol x}_i$ and ${\boldsymbol x}_j$ near ${\boldsymbol \theta}$ in parameter space, the higher the value of $\mathrm{trNTK}({\boldsymbol x}_i,{\boldsymbol x}_j)$.
  • Figure 4: $\mathrm{trNTK}$ and $\mathrm{proj\text{-}trNTK}$ cosine-similarity residuals fall exponentially. For both ResNet18 (\ref{fig:residuals1}) and Bert-base (\ref{fig:residuals2}) we plot the cumulative histogram of residuals between the $\mathrm{trNTK}$ and the $\mathrm{proj\text{-}trNTK}$. The orange line is an exponential function with k=10240. It is fit "by eye" rather than by a best-fit procedure; its purpose is only to show the exponential shape of the residual distribution.
  • Figure 5: Distinguishing independence with high covariance from true dependence. Left: A confidence-confidence scatter plot for two independent models that both have a high probability of correct classification shows a point cloud with high density at (0,0) and (1,1). These clusters act as anchors that force the Pearson correlation to be nearly 1, but because there is no underlying structure the rank correlation $\tau$ is only 0.5. Right: The dependent case, an idealized form of our surrogate model definition. The anchor structure is still present and forces the Pearson correlation to be nearly 1, but the rank correlation $\tau$ has now grown to 0.75. The main point is that Kendall-$\tau$ is less affected than the Pearson correlation by the problem of separating covariance from dependence.
  • ...and 50 more figures
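
The contrast drawn in the Figure 5 caption can be reproduced on synthetic data. The sketch below is an illustration under assumed settings, not the paper's experiment: the sample size, the noise ranges, the way dependence is injected, and the helpers `kendall_tau` and `pearson` are all hypothetical choices. Two "models" output confidences that cluster near 0 for one class and near 1 for the other; in the independent case their within-class noise is unrelated, in the dependent case it is partly shared.

```python
import numpy as np


def kendall_tau(a, b):
    """Kendall tau-a: (concordant - discordant) pairs / all pairs.

    Assumes continuous data with no ties.
    """
    sign_a = np.sign(a[:, None] - a[None, :])
    sign_b = np.sign(b[:, None] - b[None, :])
    n = len(a)
    # Ordered pairs are counted in both the sum and n * (n - 1),
    # so the double counting cancels.
    return (sign_a * sign_b).sum() / (n * (n - 1))


def pearson(a, b):
    """Pearson correlation coefficient of two 1-D arrays."""
    return np.corrcoef(a, b)[0, 1]


rng = np.random.default_rng(0)
n = 400
labels = np.repeat([0.0, 1.0], n // 2)  # half the points in each class

# Within-class noise: confidences lie in [0, 0.1] or [0.9, 1].
noise_a = rng.uniform(0.0, 0.1, n)
noise_b = rng.uniform(0.0, 0.1, n)

conf_a = np.abs(labels - noise_a)
# Independent second model: its noise is unrelated to the first model's.
conf_b_independent = np.abs(labels - noise_b)
# Dependent second model: half of its noise is shared with the first model.
conf_b_dependent = np.abs(labels - 0.5 * (noise_a + noise_b))

r_ind = pearson(conf_a, conf_b_independent)
r_dep = pearson(conf_a, conf_b_dependent)
tau_ind = kendall_tau(conf_a, conf_b_independent)
tau_dep = kendall_tau(conf_a, conf_b_dependent)

print("independent: pearson", r_ind, "kendall tau", tau_ind)
print("dependent:   pearson", r_dep, "kendall tau", tau_dep)
```

In this setup the Pearson correlation stays close to 1 in both cases because the two clusters dominate it, while the rank correlation sits near 0.5 for the independent pair and rises clearly for the dependent pair, which is the behaviour the caption describes.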