Practical Deep Heteroskedastic Regression

Mikkel Jordahn, Jonas Vestergaard Jensen, James Harrison, Michael Riis Andersen, Mikkel N. Schmidt

TL;DR

This work identifies previously undiscussed fallacies in training deep heteroskedastic regression models and proposes a simple and efficient procedure that addresses the associated practical challenges jointly by fitting a variance model post hoc across the intermediate layers of a pretrained network on a hold-out dataset.
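The post-hoc procedure summarized above can be sketched as follows. This is a minimal illustration under our own assumptions, not the authors' implementation. We assume the pretrained network is frozen, and that its mean predictions `mu` and its intermediate features `z` have already been computed on the hold-out set. We also assume the variance model is linear in the log-variance and is fit by plain gradient descent on the Gaussian NLL with a weight-decay penalty `lam`, loosely mirroring the $\lambda$ in Figure 4. The function names, the synthetic two-"layer" data, and the averaging rule in `ensemble_variance` are hypothetical.

```python
import numpy as np

def fit_log_variance(z, y, mu, lam=1e-3, lr=0.1, steps=2000):
    """Fit log sigma^2(z) = z @ w + b on hold-out data, keeping the mean mu frozen.

    Minimizes the average Gaussian NLL plus the weight-decay term lam * ||w||^2.
    z: (n, d) frozen intermediate features, y: (n,) targets, mu: (n,) mean predictions.
    """
    n, d = z.shape
    r2 = (y - mu) ** 2                      # squared residuals of the frozen mean model
    w = np.zeros(d)
    b = np.log(r2.mean() + 1e-12)           # start from the best homoskedastic fit
    for _ in range(steps):
        s = z @ w + b                       # predicted log-variance per point
        # NLL_i = 0.5 * (s_i + r2_i * exp(-s_i)) + const, so dNLL_i/ds_i is:
        g = 0.5 * (1.0 - r2 * np.exp(-s))
        w -= lr * (z.T @ g / n + 2.0 * lam * w)
        b -= lr * g.mean()
    return w, b

def predict_variance(z, w, b):
    return np.exp(z @ w + b)

def ensemble_variance(zs, params):
    # Assumed combination rule: an equal-weight mixture of Gaussians that all share
    # the frozen mean has variance equal to the average of the component variances.
    return np.mean([predict_variance(z, w, b) for z, (w, b) in zip(zs, params)], axis=0)

# Hypothetical stand-in for features taken from two layers of a pretrained network.
rng = np.random.default_rng(0)
n = 5000
z1 = rng.standard_normal((n, 2))            # layer 1: column 1 drives the noise level
z2 = np.c_[z1[:, 1] + 0.1 * rng.standard_normal(n), rng.standard_normal(n)]  # layer 2
mu = np.zeros(n)                            # frozen mean predictions
y = mu + np.exp(0.5 * (0.5 + z1[:, 1])) * rng.standard_normal(n)  # log sigma^2 = 0.5 + z1[:, 1]

w1, b1 = fit_log_variance(z1, y, mu)
w2, b2 = fit_log_variance(z2, y, mu)
var_ens = ensemble_variance([z1, z2], [(w1, b1), (w2, b2)])
print(np.round(w1, 2), round(float(b1), 2))
```

Only the small variance head is trained here, so the mean predictions are untouched by construction; averaging the per-layer variances is just one plausible way to ensemble them.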

Abstract

Uncertainty quantification (UQ) in deep learning regression is of wide interest, as it supports critical applications including sequential decision making and risk-sensitive tasks. In heteroskedastic regression, where the uncertainty of the target depends on the input, a common approach is to train a neural network that parameterizes the mean and the variance of the predictive distribution. Still, training deep heteroskedastic regression models poses practical challenges in the trade-off between uncertainty quantification and mean prediction, such as optimization difficulties, representation collapse, and variance overfitting. In this work we identify previously undiscussed fallacies and propose a simple and efficient procedure that addresses these challenges jointly by post-hoc fitting a variance model across the intermediate layers of a pretrained network on a hold-out dataset. We demonstrate that our method achieves on-par or state-of-the-art uncertainty quantification on several molecular graph datasets, without compromising mean prediction accuracy and while remaining cheap to use at prediction time.
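The Figure 1 caption describes a setting in which $x_1$ determines the mean and $x_2$ determines only the noise scale. The small simulation below makes this concrete. It draws from $y=ax_1+bx_2^2\epsilon$, with arbitrary values $a=2$ and $b=1.5$ and uniform inputs; these settings are our choices, not the paper's. A least-squares mean model recovers $a$ and gives $x_2$ a near-zero coefficient, while the residual magnitude is driven by $x_2^2$. A representation trained only for the mean therefore has no incentive to retain $x_2$, which is one plausible route to the representation collapse mentioned above.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 20000
a, b = 2.0, 1.5                                   # arbitrary illustrative coefficients
x1 = rng.uniform(-1.0, 1.0, n)
x2 = rng.uniform(-1.0, 1.0, n)
y = a * x1 + b * x2**2 * rng.standard_normal(n)   # y = a*x1 + b*x2^2*eps, eps ~ N(0, 1)

# Least-squares mean model on [x1, x2, 1]: a is recovered, the x2 coefficient is ~0.
X = np.c_[x1, x2, np.ones(n)]
coef, *_ = np.linalg.lstsq(X, y, rcond=None)

# The residual magnitude tracks x2^2 and is unrelated to x1.
res = y - X @ coef
corr_x1 = np.corrcoef(np.abs(res), np.abs(x1))[0, 1]
corr_x2 = np.corrcoef(np.abs(res), x2**2)[0, 1]
print(np.round(coef, 3), round(float(corr_x1), 3), round(float(corr_x2), 3))
```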

Paper Structure

This paper contains 46 sections, 13 equations, 18 figures, and 6 tables.

Figures (18)

  • Figure 1: Illustrative plot of data where the $x_1$ dimension fully explains $\mu(x)$, whilst $x_2$ fully explains $\sigma(x)$. (Left): 3D view of data distributed as $y=ax_1+bx_2^2\epsilon$ with $\epsilon\sim\mathcal{N}(0,1)$. (Middle): Projection of $y$ onto $x_1$, where the mean increases but the variance is constant. (Right): Projection of $y$ onto $x_2$, where the variance of the data increases but the mean remains constant. If only the mean (governed by $a$) is estimated, it suffices to look at $x_1$ alone, so a basis function learner can ignore $x_2$ completely.
  • Figure 2: QM9 OOD detection rankings for baseline methods and our post-hoc variance ensemble using the AUROC metric. Rankings are computed for every combination of target and seed and averaged. Error bars indicate the standard error of the mean.
  • Figure 3: QM9 test NLL rankings of fitting the uncertainty estimator using different individual representations $z^{l}$, all representations, or an ensemble of the individual estimators. Rankings are computed for every combination of target and seed and averaged. Error bars indicate one standard error of the mean.
  • Figure 4: The effect of varying the weight decay parameter $\lambda$ when fitting the linear uncertainty estimator using different individual representations $z^{l}$, all representations, or an ensemble of the individual estimators. Results are test NLL averages over three seeds on the QM9 $U$ target; error bars indicate one standard error.
  • Figure 5: Effect of varying the hold-out dataset size. Results are test NLL on the QM9 $U$ target, averaged over three seeds. Error bars indicate the standard error of the mean.
  • ...and 13 more figures
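The ranking procedure described in the Figure 2 and Figure 3 captions ranks the methods within every combination of target and seed and then averages the ranks. It can be sketched in NumPy as below. We assume that lower is better for NLL and higher is better for AUROC, that ties are broken by method order, and that the standard error uses the sample standard deviation over all target-seed cells. The function name and the toy scores are hypothetical.

```python
import numpy as np

def average_ranks(scores, higher_is_better=False):
    """Average per-cell ranks of methods.

    scores: (n_targets, n_seeds, n_methods). Methods are ranked within each
    (target, seed) cell (1 = best), then ranks are averaged over all cells.
    Returns (mean_rank, standard_error) per method. Ties are broken by method order.
    """
    s = -scores if higher_is_better else scores
    cells = s.reshape(-1, s.shape[-1])                   # one row per (target, seed) pair
    ranks = cells.argsort(axis=1).argsort(axis=1) + 1    # 1-based rank within each row
    mean = ranks.mean(axis=0)
    se = ranks.std(axis=0, ddof=1) / np.sqrt(ranks.shape[0])
    return mean, se

# Toy NLL values: 2 targets x 2 seeds x 3 methods (hypothetical numbers).
scores = np.array([
    [[1.0, 2.0, 3.0], [1.5, 1.2, 3.1]],
    [[0.9, 1.1, 0.8], [2.0, 2.5, 1.0]],
])
mean, se = average_ranks(scores)   # mean ranks: 1.75, 2.25, 2.0
```

For AUROC, pass `higher_is_better=True` so that the largest score receives rank 1.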