Efficient Analysis of the Distilled Neural Tangent Kernel

Jamie Mahowald, Brian Bell, Alex Ho, Michael Geyer

TL;DR

This work addresses the prohibitive cost of computing NTKs for large neural networks by introducing the Distilled Neural Tangent Kernel (DNTK), a pipeline that combines dataset distillation, random projection, and gradient distillation to compress data, gradients, and inducing points. By exploiting redundancy at the data, parameter, and gradient-subspace levels, DNTK reduces computation and storage by up to several orders of magnitude while preserving kernel structure and predictive accuracy. The approach rests on a theoretical framework linking bilevel dataset distillation to tangent-feature subspaces, and on a spectral analysis revealing low effective rank in per-class NTKs. Empirical results with ResNet-18 on ImageNette show high fidelity alongside large reductions in the number of required gradients and in kernel size, enabling scalable NTK-based analyses of practical networks and datasets.

Abstract

Neural tangent kernel (NTK) methods are computationally limited by the need to evaluate large Jacobians across many data points. Existing approaches reduce this cost primarily by projecting and sketching the Jacobian. We show that NTK computation can also be reduced by compressing the data dimension itself using NTK-tuned dataset distillation. We demonstrate that the neural tangent space spanned by the input data can be induced by dataset distillation, yielding a 20-100$\times$ reduction in required Jacobian calculations. We further show that per-class NTK matrices have low effective rank, and that this low rank is preserved by the reduction. Building on these insights, we propose the distilled neural tangent kernel (DNTK), which combines NTK-tuned dataset distillation with state-of-the-art projection methods to reduce NTK computational complexity by up to five orders of magnitude while preserving kernel structure and predictive performance.
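To make "projecting the Jacobian" concrete, here is a minimal NumPy sketch. It assumes a toy two-layer tanh network with scalar output (not ResNet-18) and a dense Gaussian projection; the helper names `mlp_jacobian`, `empirical_ntk`, and `projected_ntk` are hypothetical. The empirical NTK is the Gram matrix $K = JJ^\top$ of per-sample parameter gradients, and a random projection replaces each $p$-dimensional gradient row with a $k$-dimensional sketch.

```python
import numpy as np

def mlp_jacobian(X, W1, w2):
    """Per-sample gradients of f(x) = w2 @ tanh(W1 @ x) w.r.t. (W1, w2).

    Returns an (n, p) matrix J whose i-th row is df(x_i)/dtheta, flattened
    as [W1 (row-major), w2], so p = W1.size + w2.size.
    """
    H = np.tanh(X @ W1.T)                        # (n, h) hidden activations
    dH = 1.0 - H**2                              # tanh'(W1 x), shape (n, h)
    # df/dW1[j, k] = w2[j] * tanh'(W1 x)_j * x_k  -> shape (n, h, d)
    dW1 = (dH * w2)[:, :, None] * X[:, None, :]
    dw2 = H                                      # df/dw2 = tanh(W1 x)
    return np.concatenate([dW1.reshape(len(X), -1), dw2], axis=1)

def empirical_ntk(J):
    """Empirical NTK Gram matrix K = J J^T over the n inputs."""
    return J @ J.T

def projected_ntk(J, k, rng):
    """Sketched NTK (J P)(J P)^T with P_ij ~ N(0, 1/k).

    E[P P^T] = I, so the estimate is unbiased, and only the (n, k)
    sketch J P must be kept instead of the (n, p) Jacobian.
    """
    P = rng.standard_normal((J.shape[1], k)) / np.sqrt(k)
    JP = J @ P
    return JP @ JP.T

rng = np.random.default_rng(0)
n, d, h = 32, 16, 256
X = rng.standard_normal((n, d))
W1 = rng.standard_normal((h, d)) / np.sqrt(d)
w2 = rng.standard_normal(h) / np.sqrt(h)

J = mlp_jacobian(X, W1, w2)                      # (32, 4352)
K = empirical_ntk(J)
K_hat = projected_ntk(J, k=512, rng=rng)
rel_err = np.linalg.norm(K - K_hat) / np.linalg.norm(K)
print(J.shape, round(float(rel_err), 3))
```

Projection shrinks $p$, the number of columns of $J$. Dataset distillation, as described above, targets the other axis: it shrinks $n$, the number of rows (inputs) for which Jacobians must be computed at all.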

Paper Structure (58 sections, 6 theorems, 77 equations, 10 figures, 1 table, 1 algorithm)

Key Result

Theorem 3.3

Assume $t \sim \mathcal{T}$, $g_t := \nabla_\theta \mathcal{L}_t(\theta)$, each $\mathcal{L}_t$ is $L$-smooth, and the realized update is $\theta^+(\tilde{\mathcal{D}}) = \theta - \eta\, g_{\tilde{\mathcal{D}}}(\theta)$ with $g_{\tilde{\mathcal{D}}}(\theta) \in V(\tilde{\mathcal{D}})$. Fix $\tilde{\mathcal{D}}$ and let $\Delta\theta_t^\star := \operatorname*{argmin}_{\Delta\theta \in V(\tilde{\mathcal{D}})} \ldots$
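For orientation, applying the standard descent lemma for an $L$-smooth function to the realized update in these assumptions gives the one-step inequality below. This is a textbook consequence of $L$-smoothness, included only to show how the assumptions interact; it is not the paper's Theorem 3.3 itself. The projector $P_{V(\tilde{\mathcal{D}})}$ (orthogonal projection onto $V(\tilde{\mathcal{D}})$) is notation introduced for this illustration.

```latex
\mathcal{L}_t\big(\theta^+(\tilde{\mathcal{D}})\big)
  \;\le\; \mathcal{L}_t(\theta)
  \;-\; \eta\,\big\langle g_t,\, g_{\tilde{\mathcal{D}}}(\theta)\big\rangle
  \;+\; \frac{L\eta^2}{2}\,\big\|g_{\tilde{\mathcal{D}}}(\theta)\big\|^2,
\qquad
\big\langle g_t,\, g_{\tilde{\mathcal{D}}}(\theta)\big\rangle
  = \big\langle P_{V(\tilde{\mathcal{D}})}\, g_t,\; g_{\tilde{\mathcal{D}}}(\theta)\big\rangle .
```

The second identity holds because $g_{\tilde{\mathcal{D}}}(\theta) \in V(\tilde{\mathcal{D}})$: the component of $g_t$ orthogonal to $V(\tilde{\mathcal{D}})$ has zero inner product with the update, so only the part of $g_t$ lying in $V(\tilde{\mathcal{D}})$ can contribute to the first-order decrease.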

Figures (10)

  • Figure 1: Kernel-model accuracy metrics as a function of sample size (samples are taken evenly across classes from the 500 available distilled gradients). Experiments use the ImageNette dataset and a ResNet-18 model. Test fidelity: fraction of predictions on which $f_K$ and $f$ agree. Test MSE: computed from differences between predicted logits. Test accuracy: fraction of correct predictions on an unseen test set. Condition number and minimum eigenvalue: stability of the kernel matrices $\tilde{K}^c_{\tilde{X} \tilde{X}}$, averaged across classes. Across all metrics, a pretrained base model yields lower loss and a better-conditioned kernel than a distilled-data base model, although performance differs by 10% when only the distilled-data model is available.
  • Figure 2: Singular values of the class kernels decay exponentially, with truncation ranks between 31 and 41 (depending on the class), corresponding to $(12, 0.05)$- to $(16, 0.05)$-data redundancy.
  • Figure 3: Test metrics (fidelity, accuracy, and MSE) taken from Figure 1. Compression ratio (bottom right) is defined as $m/s$, where $m$ is the number of original gradients and $s$ is the number of gradients distilled by Algorithm 1.
  • Figure 4: Relationships between the spans of local and global eigenvectors across 10 clusters of the "tench" class, whose global truncation rank (at 95% explained variance) is 32. Top: local eigenvectors $\{\mathbf{u}_i^{j}\}_{j=1}^{r_i}$ project almost entirely onto the subspace spanned by the global eigenvectors $\{\mathbf{u}^{(r)}\}_{r=1}^{r_g}$ as the rank $r_g$ increases, demonstrating property (A) of the spectral analysis. Curves show the fraction of variance-weighted local eigenvectors contained in the first $r$ global principal components; the maximum, mean, and minimum over clusters approach 100% near the truncation rank. Middle: variance decomposition showing which global PCs each cluster uses. Cell $(i,j)$ displays the variance of cluster $i$'s kernel along global PC $j$, computed as $(\mathbf{u}^{j}|_{\mathcal{I}_i})^\top K_i (\mathbf{u}^{j}|_{\mathcal{I}_i})$ normalized by $\operatorname{tr}(K_i)$. Bright regions indicate the global dimensions that explain each cluster's structure. Bottom: coverage gap demonstrating property (B). For each global PC, the curves show the maximum (dark red) and mean (purple) alignment strength $\| P_i(\mathbf{u}^{j}|_{\mathcal{I}_i}) \|^2$ across all clusters. The orange shaded region marks global variance directions that are poorly covered by any local eigenspace, revealing that roughly $\varepsilon = 12$-$15\%$ of the global structure is not captured by the union of local clusters at the truncation rank. Analogous patterns for all ten classes appear in the appendix.
  • Figure 5: The accuracy measures from Figure 1 saturate quickly with increasing rank. The best rank-$r$ approximation is obtained by replacing $U$ and $\Sigma$ with $U^{(r)}$ and $\Sigma^{(r)}$ in the Woodbury-form kernel equation.
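Figures 2, 4, and 5 rely on two operations: choosing a truncation rank at 95% explained variance, and substituting the best rank-$r$ eigenfactors into a Woodbury-form solve. The NumPy sketch below shows both on a synthetic kernel with an exponentially decaying spectrum. It assumes a symmetric PSD kernel, measures explained variance by eigenvalue mass, and uses a kernel ridge solve with ridge $\lambda$ as a stand-in for the paper's Woodbury equation, whose exact form is not reproduced here; `truncation_rank` and `lowrank_ridge_solve` are hypothetical names.

```python
import numpy as np

def truncation_rank(K, var=0.95):
    """Smallest r whose top-r eigenvalues hold at least `var` of tr(K).

    Assumes K is symmetric PSD, so its eigenvalues equal its singular values.
    """
    evals = np.clip(np.linalg.eigvalsh(K)[::-1], 0.0, None)   # descending
    frac = np.cumsum(evals) / evals.sum()
    return int(np.searchsorted(frac, var) + 1)

def lowrank_ridge_solve(U_r, s_r, lam, Y):
    """Solve (U_r diag(s_r) U_r^T + lam I) A = Y via the Woodbury identity.

    For orthonormal U_r this reduces to
        A = (Y - U_r diag(s_r / (s_r + lam)) U_r^T Y) / lam,
    which needs only O(n r) work per column of Y instead of a dense O(n^3) solve.
    Y must be 2-D with shape (n, c).
    """
    UtY = U_r.T @ Y                                            # (r, c)
    return (Y - U_r @ ((s_r / (s_r + lam))[:, None] * UtY)) / lam

rng = np.random.default_rng(0)
n = 200
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
spectrum = np.exp(-np.arange(n) / 8.0)            # synthetic exponential decay
K = (Q * spectrum) @ Q.T                          # K = Q diag(spectrum) Q^T

r = truncation_rank(K, var=0.95)                  # 24 for this spectrum
evals, evecs = np.linalg.eigh(K)
top = np.argsort(evals)[::-1][:r]
U_r, s_r = evecs[:, top], evals[top]              # best rank-r factors

lam = 1e-2
Y = rng.standard_normal((n, 3))                   # e.g. targets for 3 outputs
A_r = lowrank_ridge_solve(U_r, s_r, lam, Y)
print(r, A_r.shape)
```

With $\lambda_i = e^{-i/8}$, the cumulative fraction first reaches 0.95 at $r = 24$ because $e^{-3} \approx 0.0498 < 0.05$. The Woodbury form never materializes an $n \times n$ inverse, which is what makes sweeping $r$ (as in Figure 5) cheap.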

Theorems & Definitions (21)

  • Definition 3.1: Data redundancy
  • Definition 3.2: Parameter redundancy
  • Theorem 3.3: One-step smoothness regret bound
  • proof
  • Remark 3.4: Coefficient realizability.
  • Corollary 3.5: Competing objectives $\Rightarrow$ PCA subspace of gradient covariance
  • proof
  • Proposition 3.6: Energy-gap decomposition in gradient-feature space
  • proof
  • Remark 4.1
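Corollary 3.5 and Proposition 3.6 appear above only by title, so the following NumPy sketch illustrates the generic facts those titles point to rather than their exact statements. Among all $r$-dimensional subspaces, the span of the top-$r$ eigenvectors of the gradient second-moment matrix retains the most mean squared gradient energy. For any orthonormal basis, the total energy splits exactly into a captured part and a residual gap. The synthetic gradients and the helper names `pca_gradient_subspace` and `captured_energy` are hypothetical.

```python
import numpy as np

def pca_gradient_subspace(G, r):
    """Orthonormal (p, r) basis of the top-r eigenvectors of G^T G / m.

    G has shape (m, p), one per-example (or per-task) gradient per row.
    This uses the uncentered second moment; subtract G.mean(0) first
    for a centered covariance.
    """
    C = G.T @ G / G.shape[0]
    evals, evecs = np.linalg.eigh(C)
    return evecs[:, np.argsort(evals)[::-1][:r]]

def captured_energy(G, V):
    """Mean squared norm of the gradients projected onto span(V), V orthonormal."""
    return float(np.mean(np.sum((G @ V) ** 2, axis=1)))

rng = np.random.default_rng(0)
m, p, r = 500, 50, 5
scales = np.concatenate([np.full(r, 5.0), np.full(p - r, 0.5)])
G = rng.standard_normal((m, p)) * scales          # synthetic anisotropic gradients

V_pca = pca_gradient_subspace(G, r)
V_rand, _ = np.linalg.qr(rng.standard_normal((p, r)))
total = captured_energy(G, np.eye(p))             # mean ||g||^2

for name, V in [("pca", V_pca), ("random", V_rand)]:
    kept = captured_energy(G, V)
    resid = float(np.mean(np.sum((G - (G @ V) @ V.T) ** 2, axis=1)))
    # Pythagoras: total = kept + resid for any orthonormal V.
    print(name, round(kept / total, 3), round(resid / total, 3))
```

The PCA subspace keeps most of the energy while a random subspace of the same dimension keeps roughly $r/p$ of it; the residual column is the energy gap left outside the chosen subspace.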