Uncertainty Quantification with the Empirical Neural Tangent Kernel
Joseph Wilson, Chris van der Heide, Liam Hodgkinson, Fred Roosta
TL;DR
NUQLS is a post-hoc, sampling-based uncertainty quantification method for over-parameterized neural networks. It linearizes the network around its trained parameters and trains an ensemble of these linearized predictors with stochastic gradient descent. Under mild assumptions, the ensemble's predictive distribution matches that of a Gaussian process (GP) whose kernel is the empirical Neural Tangent Kernel (NTK), which links NN training dynamics to GP posteriors. Empirically, NUQLS achieves state-of-the-art or competitive uncertainty quantification across regression and classification tasks, with robust performance on OoD detection and calibration, while being far more computationally efficient than deep ensembles. The work provides a practical and scalable bridge between neural networks, Gaussian processes, and the NTK, giving accurate Bayesian uncertainty estimates without extensive retraining or specialized architectures.
Abstract
While neural networks have demonstrated impressive performance across various tasks, accurately quantifying uncertainty in their predictions is essential to ensure their trustworthiness and enable widespread adoption in critical systems. Several Bayesian uncertainty quantification (UQ) methods exist that are either cheap or reliable, but not both. We propose a post-hoc, sampling-based UQ method for over-parameterized networks at the end of training. Our approach constructs efficient and meaningful deep ensembles by employing a (stochastic) gradient-descent sampling process on appropriately linearized networks. We demonstrate that our method effectively approximates the posterior of a Gaussian process using the empirical Neural Tangent Kernel. Through a series of numerical experiments, we show that our method not only outperforms competing approaches in computational efficiency, often reducing costs by multiple factors, but also maintains state-of-the-art performance across a variety of UQ metrics for both regression and classification tasks.
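The abstract does not spell out the sampling procedure, so the following is only a minimal NumPy sketch of the general idea, not the authors' implementation. It assumes a toy one-hidden-layer tanh network, a squared loss, full-batch rather than stochastic gradient descent, and Gaussian perturbations of the trained parameters as ensemble initializations. The names `model`, `jacobian`, `theta_hat`, and the values of `H`, `S`, and `gamma` are all hypothetical choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
H = 20                      # hidden width (assumption); p = 3 * H parameters
p = 3 * H

def model(theta, x):
    """Toy network f(x) = sum_j a_j * tanh(w_j * x + b_j), scalar in and out."""
    w, b, a = np.split(theta, 3)
    return np.tanh(np.outer(x, w) + b) @ a

def jacobian(theta, x):
    """Analytic Jacobian of the outputs w.r.t. theta, shape (len(x), p)."""
    w, b, a = np.split(theta, 3)
    h = np.tanh(np.outer(x, w) + b)          # hidden activations, (n, H)
    dh = (1.0 - h**2) * a                    # a_j * tanh'(w_j x + b_j), (n, H)
    return np.hstack([dh * x[:, None], dh, h])

# Small over-parameterized regression problem: n = 8 points, p = 60 parameters.
x_train = np.linspace(-2.0, 2.0, 8)
y_train = np.sin(2.0 * x_train) + 0.05 * rng.normal(size=x_train.size)

# Stand-in for "the trained network": a few hundred plain gradient steps.
theta_hat = np.concatenate([rng.normal(size=2 * H), rng.normal(size=H) / np.sqrt(H)])
for _ in range(500):
    J_t = jacobian(theta_hat, x_train)
    step = 1.0 / np.linalg.norm(J_t, 2) ** 2
    theta_hat -= step * J_t.T @ (model(theta_hat, x_train) - y_train)

# Linearize around theta_hat: f_lin(x; theta) = f(x; theta_hat) + J(x) (theta - theta_hat).
J = jacobian(theta_hat, x_train)             # (n, p); empirical NTK is J @ J.T
f_hat = model(theta_hat, x_train)
lr = 1.0 / np.linalg.norm(J, 2) ** 2         # 1 / lambda_max(J J^T), a safe step size

# Ensemble of S linearized predictors, each started from a perturbed theta_hat.
S, gamma = 10, 1.0
theta0 = theta_hat + gamma * rng.normal(size=(S, p))
thetas = theta0.copy()
for _ in range(5000):
    resid = f_hat + (thetas - theta_hat) @ J.T - y_train   # (S, n) residuals
    thetas -= lr * resid @ J                                # gradient of 0.5 * ||resid||^2

# Predictive mean and variance from the ensemble of linearized predictions.
x_test = np.linspace(-4.0, 4.0, 50)
J_test = jacobian(theta_hat, x_test)
preds = model(theta_hat, x_test) + (thetas - theta_hat) @ J_test.T   # (S, n_test)
mean, var = preds.mean(axis=0), preds.var(axis=0)
```

In this sketch every gradient update lies in the row space of `J`, so the part of each initial perturbation that `J` cannot see is never trained away. That untouched component is what produces the spread in `preds`. If the Jacobian has full row rank and gradient descent is run to convergence, the ensemble covariance under these assumptions tends toward `gamma**2` times the GP posterior covariance for the kernel `J @ J.T`, which is the kind of correspondence the abstract describes.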
