Fast Neural Tangent Kernel Alignment, Norm and Effective Rank via Trace Estimation
James Hazelden
TL;DR
This work addresses the computational bottleneck of analyzing the empirical finite-width NTK by introducing matrix-free trace estimation methods. It leverages Hutch++ and one-sided Hutchinson-type estimators to rapidly compute NTK-derived metrics, including $\mathrm{tr}(\text{NTK})$, $\|\text{NTK}\|_F$, $\cos(\text{NTK}_1, \text{NTK}_2)$, and $r_{\text{eff}}(\text{NTK})$, using only forward- or reverse-mode automatic differentiation. Through extensive numerical validation on an MLP and a GRU, the authors demonstrate substantial speedups (often orders of magnitude) over exact trace calculations while maintaining controllable accuracy, enabling practical NTK analysis for larger architectures. The methods show complementary behavior across architectures and accuracy regimes, with Hutch++ excelling at high accuracy and one-sided estimators providing fast estimates at low sample counts. The work also outlines limitations, potential extensions (e.g., partial traces, NTK product estimators), and provides code for reproducibility and broader applicability in NTK-based analyses.
Abstract
The Neural Tangent Kernel (NTK) characterizes how a model's state evolves under gradient descent. Computing the full NTK matrix is often infeasible, especially for recurrent architectures. Here, we introduce a matrix-free perspective, using trace estimation to rapidly analyze the empirical, finite-width NTK. This enables fast computation of the NTK's trace, Frobenius norm, effective rank, and alignment. We provide numerical recipes based on the Hutch++ trace estimator, which has provably fast convergence guarantees. In addition, we show that, due to the structure of the NTK, one can compute the trace using only forward- or reverse-mode automatic differentiation, without requiring both modes. We show that these so-called one-sided estimators can outperform Hutch++ in the low-sample regime, especially when the gap between the size of the model state and the parameter count is large. Overall, our results demonstrate that matrix-free randomized approaches can yield speedups of many orders of magnitude, leading to faster analysis and applications of the NTK.
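
To make the one-sided idea concrete, here is a minimal sketch of one plausible reading of it, not the paper's implementation. It assumes the empirical NTK factors as $J J^\top$, where $J$ is the Jacobian of the model state with respect to the parameters, so that $\mathrm{tr}(J J^\top) = \mathbb{E}\,\|J^\top v\|^2 = \mathbb{E}\,\|J w\|^2$ for random sign vectors $v$ and $w$. A small explicit matrix `J` stands in for the network, and the helpers `vjp` and `jvp` are hypothetical stand-ins for reverse- and forward-mode autodiff products. In a real setting `J` would never be formed.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in Jacobian of the model state w.r.t. the parameters (assumption:
# a random matrix replaces a real network so the result can be checked).
n_out, n_params = 50, 400
J = rng.standard_normal((n_out, n_params)) / np.sqrt(n_params)


def vjp(v):
    """Hypothetical reverse-mode product J^T v (one backward pass)."""
    return J.T @ v


def jvp(w):
    """Hypothetical forward-mode product J w (one forward-mode pass)."""
    return J @ w


def trace_reverse_only(num_samples):
    """Estimate tr(J J^T) as the mean of ||J^T v||^2 over sign vectors v."""
    total = 0.0
    for _ in range(num_samples):
        v = rng.choice([-1.0, 1.0], size=n_out)  # probe in output space
        total += np.sum(vjp(v) ** 2)
    return total / num_samples


def trace_forward_only(num_samples):
    """Estimate tr(J^T J) = tr(J J^T) as the mean of ||J w||^2."""
    total = 0.0
    for _ in range(num_samples):
        w = rng.choice([-1.0, 1.0], size=n_params)  # probe in parameter space
        total += np.sum(jvp(w) ** 2)
    return total / num_samples


# Exact value for comparison: tr(J J^T) is the squared Frobenius norm of J.
exact = np.sum(J ** 2)
est_rev = trace_reverse_only(500)
est_fwd = trace_forward_only(500)

print(f"exact trace        : {exact:.3f}")
print(f"reverse-mode only  : {est_rev:.3f}")
print(f"forward-mode only  : {est_fwd:.3f}")
```

Each sample costs a single autodiff product in one mode, whereas applying the NTK itself to a probe vector, $J(J^\top v)$, needs both a reverse-mode and a forward-mode pass. Hutch++ is not shown here.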
