Uncertainty Quantification with the Empirical Neural Tangent Kernel

Joseph Wilson, Chris van der Heide, Liam Hodgkinson, Fred Roosta

TL;DR

NUQLS is a post-hoc, sampling-based uncertainty quantification (UQ) method for over-parameterized neural networks. It trains an ensemble of linearized predictors around the trained parameters using stochastic gradient descent. Under mild assumptions, the predictive distribution of this ensemble matches that of a Gaussian process (GP) with the empirical Neural Tangent Kernel (NTK), which links neural network training dynamics to GP posteriors. Empirically, NUQLS achieves state-of-the-art or competitive UQ on regression and classification tasks at a far lower computational cost than deep ensembles, and it performs robustly on out-of-distribution (OoD) detection and calibration. The method offers a practical, scalable bridge between neural networks, GPs, and the NTK, providing Bayesian uncertainty estimates without extensive retraining or specialized architectures.
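
To make the claimed link between a linearized ensemble and a GP posterior concrete, here is a short worked derivation for the simplest case. It is a sketch under assumptions that are not stated on this page: squared loss, full-batch gradient descent run to convergence, a training Jacobian $J = \nabla_\theta f(X;\hat\theta)$ with full row rank, and ensemble members initialized at $\hat\theta + z$ with $z \sim \mathcal{N}(0, \gamma^2 I)$. The symbols $\tilde f$, $k$, $K$, $\gamma$, and $x_*$ are introduced here for illustration; the paper's exact hypotheses are those of Theorem 3.2 and the accompanying remarks.

```latex
% Assumed setting: squared loss, converged full-batch GD, full-row-rank J,
% initialization \hat\theta + z with z ~ N(0, \gamma^2 I).
\begin{align*}
\tilde f(x;\theta) &= f(x;\hat\theta) + \nabla_\theta f(x;\hat\theta)^\top (\theta - \hat\theta), \\
k(x,x') &= \nabla_\theta f(x;\hat\theta)^\top \nabla_\theta f(x';\hat\theta),
  \qquad K = k(X,X) = J J^\top, \\
\theta_\infty - \hat\theta &= (I - J^\top K^{-1} J)\, z + J^\top K^{-1}\bigl(\mathbf{y} - f(X;\hat\theta)\bigr), \\
\mathbb{E}\bigl[\tilde f(x_*;\theta_\infty)\bigr] &= f(x_*;\hat\theta) + k(x_*,X)\, K^{-1} \bigl(\mathbf{y} - f(X;\hat\theta)\bigr), \\
\operatorname{Var}\bigl[\tilde f(x_*;\theta_\infty)\bigr] &= \gamma^2 \bigl( k(x_*,x_*) - k(x_*,X)\, K^{-1} k(X,x_*) \bigr).
\end{align*}
```

The last two lines are the textbook noise-free GP posterior mean and variance for the kernel $\gamma^2 k$ with prior mean $f(\cdot\,;\hat\theta)$. The variance line uses the fact that $I - J^\top K^{-1} J$ is an orthogonal projection.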

Abstract

While neural networks have demonstrated impressive performance across various tasks, accurately quantifying uncertainty in their predictions is essential to ensure their trustworthiness and enable widespread adoption in critical systems. Several Bayesian uncertainty quantification (UQ) methods exist that are either cheap or reliable, but not both. We propose a post-hoc, sampling-based UQ method for over-parameterized networks at the end of training. Our approach constructs efficient and meaningful deep ensembles by employing a (stochastic) gradient-descent sampling process on appropriately linearized networks. We demonstrate that our method effectively approximates the posterior of a Gaussian process using the empirical Neural Tangent Kernel. Through a series of numerical experiments, we show that our method not only outperforms competing approaches in computational efficiency, often reducing costs by multiple factors, but also maintains state-of-the-art performance across a variety of UQ metrics for both regression and classification tasks.
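
The sampling process described above can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. It assumes squared loss, full-batch gradient descent in place of SGD, a tiny scalar MLP with a hand-written Jacobian, and Gaussian perturbations of scale `gamma` around the trained parameters. The names `mlp`, `jacobian`, `linearized_ensemble`, `gamma`, and `steps` are hypothetical, and `theta_hat` is a random placeholder standing in for trained parameters.

```python
import numpy as np

def mlp(theta, X, width):
    """Scalar-input, scalar-output MLP with one tanh hidden layer."""
    w1, b1, w2 = theta[:width], theta[width:2 * width], theta[2 * width:]
    return np.tanh(np.outer(X, w1) + b1) @ w2

def jacobian(theta, X, width):
    """Analytic Jacobian of mlp w.r.t. theta, shape (len(X), 3 * width)."""
    w1, b1, w2 = theta[:width], theta[width:2 * width], theta[2 * width:]
    h = np.tanh(np.outer(X, w1) + b1)   # hidden activations, (n, width)
    dh = (1.0 - h ** 2) * w2            # d f / d pre-activation, (n, width)
    return np.hstack([dh * X[:, None], dh, h])

def linearized_ensemble(theta_hat, X, y, X_test, width,
                        n_members=20, gamma=1.0, steps=500, seed=0):
    """Fit n_members linearized models by gradient descent; return mean, variance."""
    rng = np.random.default_rng(seed)
    J = jacobian(theta_hat, X, width)
    J_test = jacobian(theta_hat, X_test, width)
    f0, f0_test = mlp(theta_hat, X, width), mlp(theta_hat, X_test, width)
    # Step size 1/L, where L is the largest eigenvalue of J^T J / n (stable for GD).
    lr = len(y) / np.linalg.norm(J, 2) ** 2
    preds = []
    for _ in range(n_members):
        # Assumed initialization: Gaussian offset from the trained parameters.
        delta = gamma * rng.standard_normal(theta_hat.shape)
        for _ in range(steps):
            resid = f0 + J @ delta - y              # linearized-model residual
            delta -= lr * (J.T @ resid) / len(y)    # gradient of 0.5 * mean(resid^2)
        preds.append(f0_test + J_test @ delta)
    preds = np.stack(preds)
    return preds.mean(axis=0), preds.var(axis=0)

# Toy data in the spirit of Figure 1: noisy samples of y = x^3.
width = 16
rng = np.random.default_rng(1)
X = np.linspace(-1.0, 1.0, 8)
y = X ** 3 + 0.05 * rng.standard_normal(X.shape)
theta_hat = rng.standard_normal(3 * width)  # placeholder, NOT actually trained
X_test = np.linspace(-2.0, 2.0, 5)
mean, var = linearized_ensemble(theta_hat, X, y, X_test, width)
```

Because each member is linear in its parameters, every gradient step only needs the fixed Jacobian at `theta_hat`, which is why such an ensemble can be much cheaper than retraining the full network several times. In a real setting the Jacobian would typically be accessed through automatic differentiation rather than formed explicitly.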

Paper Structure

This paper contains 60 sections, 3 theorems, 13 equations, 7 figures, 13 tables, 2 algorithms.

Key Result

Lemma 3.1

Suppose the loss, $\ell(\,\cdot\,, \mathbf{y})$, is either: … The problem \eqref{eq:ell_range} admits a unique solution.

Figures (7)

  • Figure 1: Comparison of various Bayesian UQ methods (see \ref{sec:background}) on a 1-layer MLP, trained on the data (red) lying on $y = x^3$ (black), with Gaussian noise added. The methods' mean predictors (blue) $\pm 3\sigma$ (green) are shown, where $\sigma^2$ is the variance estimated via each method. We see that NUQLS performs well on this task.
  • Figure 2: Plot of SEV ($\newmoon$) and NUQLS train loss ($\blacktriangle$) against (top) number of epochs of training for NUQLS and (bottom) number of NUQLS realizations. Bracketed number is condition number of NTK Gram matrix. Mean and $95\%$ confidence intervals are shown from $10$ random realizations.
  • Figure 3: Violin plot of VMSP, for correctly predicted ID test points, incorrectly predicted ID test points, and OoD test points. Median is shown, with violin width depicting density. Low variance is expected for ID correct points, and large variance for ID incorrect and OoD points. Title of plots gives model and dataset used for training.
  • Figure 4: Comparison of BDE, DE and NUQLS on the toy regression problem from \ref{fig:toy_regression}. We can see that the uncertainty of the BDE method is quite small.
  • Figure 5: Violin plot of VMSP, with an ensembled version of NUQLS, eNUQLS, included.
  • ...and 2 more figures

Theorems & Definitions (9)

  • Remark 1.1: Necessity of Contribution 4.
  • Lemma 3.1
  • Theorem 3.2
  • Remark 3.3: Connections to GP: Regression
  • Remark 3.4: Connections to GP: General Loss
  • Corollary 3.5: Key Property of NUQLS
  • Remark 3.6
  • Proof: Proof of \ref{lemma:unique_sol}
  • Proof: Proof of \ref{theorem:gd_genloss}