Fast Neural Tangent Kernel Alignment, Norm and Effective Rank via Trace Estimation

James Hazelden

TL;DR

This work addresses the computational bottleneck of analyzing the empirical finite-width NTK by introducing matrix-free trace estimation methods. It leverages Hutch++ and one-sided Hutchinson-type estimators to rapidly compute NTK-derived metrics, including $\text{tr}(\text{NTK})$, $\|\text{NTK}\|_F$, $\cos(\text{NTK}_1, \text{NTK}_2)$ and $r_{\text{eff}}(\text{NTK})$, and shows that the trace can be estimated with only forward- or only reverse-mode automatic differentiation. Through numerical experiments on an MLP and a GRU, the author demonstrates substantial speedups (often orders of magnitude) over exact trace computation while maintaining controllable accuracy, making NTK analysis practical for larger architectures. The methods behave complementarily across architectures and accuracy regimes: Hutch++ excels at high accuracy, while one-sided estimators provide fast estimates at low sample counts. The work also outlines limitations and potential extensions (e.g., partial traces and NTK product estimators), and provides code for reproducibility and broader use in NTK-based analyses.
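To pin down what the four metrics measure, here is a minimal dense NumPy sketch that computes them exactly for a small model. It assumes the empirical NTK is the Gram matrix $JJ^\top$ of the Jacobian $J = \partial h/\partial\theta \in \mathbb{R}^{n \times p}$, and it assumes the participation-ratio definition $r_{\text{eff}}(K) = \text{tr}(K)^2/\|K\|_F^2$, which may differ from the paper's. The helper names `ntk_from_jacobian` and `ntk_metrics` are hypothetical. This exact route forms the full $n \times n$ matrix, which is precisely the cost the trace estimators avoid.

```python
import numpy as np

def ntk_from_jacobian(J):
    """Empirical NTK as the n x n Gram matrix J J^T, with J = dh/dtheta of shape (n, p)."""
    return J @ J.T

def ntk_metrics(K1, K2):
    """Exact dense reference values of the four NTK metrics (hypothetical helper)."""
    trace = np.trace(K1)
    fro = np.linalg.norm(K1, "fro")
    # Kernel alignment: Frobenius inner product <K1, K2>_F = tr(K1^T K2), normalized.
    cos = np.sum(K1 * K2) / (fro * np.linalg.norm(K2, "fro"))
    # Assumed definition of effective rank: participation ratio tr(K)^2 / ||K||_F^2.
    r_eff = trace ** 2 / fro ** 2
    return trace, fro, cos, r_eff

# Usage on a small random Jacobian and a slightly perturbed copy of it.
rng = np.random.default_rng(0)
J1 = rng.standard_normal((20, 50))
J2 = J1 + 0.1 * rng.standard_normal((20, 50))
K1, K2 = ntk_from_jacobian(J1), ntk_from_jacobian(J2)
tr, fro, cos, r_eff = ntk_metrics(K1, K2)
```

Note that $\text{tr}(JJ^\top) = \|J\|_F^2$, so the trace is simply the total squared magnitude of the Jacobian.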

Abstract

The Neural Tangent Kernel (NTK) characterizes how a model's state evolves during gradient descent. Computing the full NTK matrix is often infeasible, especially for recurrent architectures. Here, we introduce a matrix-free perspective, using trace estimation to rapidly analyze the empirical, finite-width NTK. This enables fast computation of the NTK's trace, Frobenius norm, effective rank, and alignment. We provide numerical recipes based on the Hutch++ trace estimator, which has provably fast convergence guarantees. In addition, we show that, due to the structure of the NTK, one can compute the trace using only forward- or only reverse-mode automatic differentiation, rather than both. We show that these so-called one-sided estimators can outperform Hutch++ in the low-sample regime, especially when the gap between the model's state dimension and its parameter count is large. Overall, our results demonstrate that matrix-free randomized approaches can yield speedups of many orders of magnitude, enabling faster analysis and applications of the NTK.
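The one-sided idea can be illustrated with a standard identity. Assuming $\texttt{NTK} = JJ^\top$ with $J = \partial h/\partial\theta \in \mathbb{R}^{n \times p}$, we have $\text{tr}(JJ^\top) = \mathbb{E}\|J^\top v\|^2$ for random sign probes $v \in \mathbb{R}^n$ (reverse mode only, RHutch) and $\text{tr}(J^\top J) = \mathbb{E}\|Jw\|^2$ for $w \in \mathbb{R}^p$ (forward mode only, FHutch), whereas every NTK matvec inside Hutch++ costs one VJP followed by one JVP. The sketch below is not the paper's released code: a dense `J` stands in for automatic differentiation (in JAX, `jax.jvp` and `jax.vjp` would supply these products), Rademacher probes are an assumption, and all function names are hypothetical.

```python
import numpy as np

def make_linear_model(J):
    """Dense stand-in for a model with Jacobian J = dh/dtheta (n x p).
    In practice these products would come from JVP/VJP autodiff, not an explicit J."""
    jvp = lambda w: J @ w      # forward mode: R^p -> R^n
    vjp = lambda u: J.T @ u    # reverse mode: R^n -> R^p
    return jvp, vjp

def apply_cols(f, M):
    """Apply a vector map f to each column of M (one AD call per column)."""
    return np.column_stack([f(M[:, i]) for i in range(M.shape[1])])

def rademacher(rng, rows, cols):
    return rng.choice([-1.0, 1.0], size=(rows, cols))

def rhutch(vjp, n, m, rng):
    """One-sided, reverse mode only: tr(J J^T) = E[||J^T v||^2], v in R^n."""
    return np.sum(apply_cols(vjp, rademacher(rng, n, m)) ** 2) / m

def fhutch(jvp, p, m, rng):
    """One-sided, forward mode only: tr(J^T J) = E[||J w||^2], w in R^p."""
    return np.sum(apply_cols(jvp, rademacher(rng, p, m)) ** 2) / m

def hutchpp_ntk(jvp, vjp, n, m, rng):
    """Hutch++ on NTK = J J^T with at most m NTK matvecs (each one VJP then one JVP)."""
    assert m >= 3, "Hutch++ needs at least 3 matvecs"
    ntk_mv = lambda v: jvp(vjp(v))
    k = m // 3
    # Low-rank part: orthonormal basis Q for the range of NTK @ S, trace computed exactly.
    Q, _ = np.linalg.qr(apply_cols(ntk_mv, rademacher(rng, n, k)))
    t_low = np.trace(Q.T @ apply_cols(ntk_mv, Q))
    # Residual part: Hutchinson on (I - QQ^T) NTK (I - QQ^T) with the leftover matvecs.
    G = rademacher(rng, n, m - 2 * k)
    G_perp = G - Q @ (Q.T @ G)
    t_res = np.trace(G_perp.T @ apply_cols(ntk_mv, G_perp)) / G.shape[1]
    return t_low + t_res

# Toy setting with far more parameters than state variables (n << p).
rng = np.random.default_rng(0)
n, p, m = 100, 2000, 30
J = rng.standard_normal((n, p)) / np.sqrt(p)
jvp, vjp = make_linear_model(J)
exact = np.sum(J ** 2)  # tr(NTK) = tr(J J^T) = ||J||_F^2, which needs the full J
est_r = rhutch(vjp, n, m, rng)
est_f = fhutch(jvp, p, m, rng)
est_pp = hutchpp_ntk(jvp, vjp, n, m, rng)
```

Here all three estimators use the same number of probes `m`, so Hutch++ spends roughly twice as many AD calls as either one-sided estimator. When the NTK has (numerically) low rank, the deflation step of Hutch++ captures most of the trace exactly, which is where its accuracy advantage comes from.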

Paper Structure

This paper contains 31 sections, 19 equations, 4 figures, 1 table and 2 algorithms.

Figures (4)

  • Figure 1: Approximating $\text{tr}(\texttt{NTK})$: speedup versus accuracy, summarizing the numerical results section. A corresponds to the MLP explored in the MLP section, with 3,200 NTK state variables and around 64,700 parameters. Achieving $99\%$ accuracy was about $80$ times faster than computing the NTK trace exactly with $n$ evaluations, while achieving $99.999\%$ accuracy was about $30$ times faster. B corresponds to the GRU in the GRU section. Reaching $99\%$ accuracy was over 10,000 times faster, while $99.999\%$ could be attained about 70 times faster. Here, each estimator from the methods section (Hutch++, RHutch, FHutch) was run 50 times; we recorded which method reached each accuracy level fastest and plotted its speedup relative to the exact trace evaluation time. Curves show the median across runs, and shaded regions show the $25^{\text{th}}$ and $75^{\text{th}}$ percentiles.
  • Figure 2: Estimating $\text{tr}(\texttt{NTK})$ when $h(\theta)$ tracks the output of an MLP. A illustrates the trace estimates for the three approaches in the methods section: Hutch++ (the Hutch++ NTK algorithm) and the one-sided estimator (the one-sided trace algorithm) with forward- and reverse-mode AD (FHutch and RHutch, respectively). B shows the relative error $\frac{|\text{tr}(\texttt{NTK}) - t_{m}|}{|t_m|}$ of each estimate $t_m$ versus runtime in seconds; for reference, computing the trace exactly took 1.96 seconds. C shows the low-runtime region of B in more detail with a linear runtime axis, showing that FHutch is the slowest for this setup. Note that Hutch++ reached a relative error below $10^{-6}$ in about 0.06 seconds (about $33$ times faster than the exact trace).
  • Figure 3: Estimating $\text{tr}(\texttt{NTK})$ when $h(\theta)$ tracks the hidden-unit activations of a GRU at every evaluation time. As in Fig. 2, A illustrates the trace estimates for Hutch++ and the one-sided Hutch estimators (the Hutch++ NTK and one-sided trace algorithms). B illustrates the relative error $\frac{|\text{tr}(\texttt{NTK}) - t_{m}|}{|t_m|}$ of each estimate $t_m$ versus runtime in seconds. In this case, computing the trace exactly with $n$ matvecs took 1557.72 seconds (around 26 minutes). C shows the low-runtime region of B in more detail with a linear runtime axis. In contrast to Fig. 2, FHutch performed fastest, and both one-sided estimators outperformed Hutch++ for relative errors above $2 \cdot 10^{-3}$. Hutch++ was fastest for accuracy above $99.9\%$.
  • Figure 4: Relative runtime versus error for an MLP trained on MNIST; panels A, B and C estimate the Frobenius norm, kernel alignment and effective rank, respectively. Here, $\texttt{NTK}_f$ denotes the NTK of the model after training, while $\texttt{NTK}_0$ is the NTK before training. Panels A-C plot the relative error with respect to the exact value (computed with exact trace evaluation) versus the percentage of the total runtime required to compute the exact expression in each case (range $0\%$ to $100\%$). Curves correspond to the median over 50 re-evaluations with Hutch++, and shaded regions illustrate the $25^{\text{th}}$ and $75^{\text{th}}$ percentiles.
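The Figure 4 quantities all reduce to traces of NTK products: $\|K\|_F^2 = \text{tr}(K^2)$ and, since the NTK is symmetric, $\langle K_f, K_0\rangle_F = \text{tr}(K_f K_0)$. The sketch below estimates them matrix-free from NTK matvecs. It swaps in plain Hutchinson with shared Rademacher probes for brevity, whereas the figure reports Hutch++. It also assumes the participation-ratio effective rank $\text{tr}(K)^2/\|K\|_F^2$, and the toy "before" and "after training" Jacobians as well as all function names are hypothetical.

```python
import numpy as np

def ntk_matvec(J):
    """Hypothetical dense stand-in for v -> NTK v = J (J^T v): one VJP, then one JVP."""
    return lambda v: J @ (J.T @ v)

def hutchinson_ntk_stats(mv1, mv2, n, m, rng):
    """Hutchinson-style estimates from m shared Rademacher probes v in R^n:
    tr(K1) ~ mean v^T K1 v, ||K||_F^2 = tr(K^2) ~ mean ||K v||^2, and, since the
    NTK is symmetric, <K1, K2>_F = tr(K1 K2) ~ mean (K1 v) . (K2 v)."""
    V = rng.choice([-1.0, 1.0], size=(n, m))
    K1V = np.column_stack([mv1(V[:, i]) for i in range(m)])
    K2V = np.column_stack([mv2(V[:, i]) for i in range(m)])
    tr1 = np.sum(V * K1V) / m
    fro1_sq = np.sum(K1V ** 2) / m
    fro2_sq = np.sum(K2V ** 2) / m
    inner = np.sum(K1V * K2V) / m
    return {
        "trace": tr1,
        "fro": np.sqrt(fro1_sq),
        "alignment": inner / np.sqrt(fro1_sq * fro2_sq),
        # Assumed effective-rank definition: participation ratio tr(K)^2 / ||K||_F^2.
        "r_eff": tr1 ** 2 / fro1_sq,
    }

# Hypothetical "before" (J0) and "after training" (Jf) Jacobians for a toy model.
rng = np.random.default_rng(1)
n, p, m = 200, 800, 100
J0 = rng.standard_normal((n, p)) / np.sqrt(p)
Jf = J0 + 0.3 * rng.standard_normal((n, p)) / np.sqrt(p)
est = hutchinson_ntk_stats(ntk_matvec(Jf), ntk_matvec(J0), n, m, rng)

# Dense reference values, feasible here only because this toy NTK is small.
K0, Kf = J0 @ J0.T, Jf @ Jf.T
exact_fro = np.linalg.norm(Kf, "fro")
exact_align = np.sum(Kf * K0) / (exact_fro * np.linalg.norm(K0, "fro"))
exact_reff = np.trace(Kf) ** 2 / exact_fro ** 2
```

Reusing the same probes for both kernels means each probe costs one matvec per NTK, and the alignment ratio benefits from correlated errors in its numerator and denominator; the price is a small bias, since a ratio of unbiased estimates is not itself unbiased.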