
Variation Due to Regularization Tractably Recovers Bayesian Deep Learning

James McInerney, Nathan Kallus

TL;DR

This work addresses epistemic uncertainty in large deep networks by introducing Regularization Variation (RegVar), a method that estimates predictive variance from the change in outputs when a small regularization term is added to the loss. RegVar avoids explicit Hessian inversion by defining a prediction-regularized MAP objective and, in its amortized form, scales to many inputs. The authors prove that RegVar recovers the linearized Laplace variance in the infinitesimal limit. On large language and vision models, they show performance competitive with state-of-the-art uncertainty quantification methods, with improved calibration and better results in some out-of-distribution settings. These results suggest that RegVar is a practical, scalable Bayesian deep learning tool that leverages existing training frameworks to improve uncertainty estimation in real-world applications.

Abstract

Uncertainty quantification in deep learning is crucial for safe and reliable decision-making in downstream tasks. Existing methods quantify uncertainty at the last layer or through other approximations of the network, which may miss some sources of uncertainty in the model. To address this gap, we propose an uncertainty quantification method for large networks based on variation due to regularization. Essentially, predictions that are more (less) sensitive to the regularization of network parameters are less (more, respectively) certain. This principle can be implemented by deterministically tweaking the training loss during the fine-tuning phase, and it reflects confidence in the output as a function of all layers of the network. We show that regularization variation (RegVar) provides rigorous uncertainty estimates that, in the infinitesimal limit, exactly recover the Laplace approximation in Bayesian deep learning. We demonstrate its success across several deep learning architectures, showing that it scales tractably with network size while maintaining or improving uncertainty quantification quality. Our experiments across multiple datasets show that RegVar not only identifies uncertain predictions effectively but also provides insights into the stability of learned representations.
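The principle that sensitivity to regularization signals uncertainty can be checked on a toy problem. The sketch below is illustrative and is not the authors' code. It rests on three assumptions:

- ridge regression, meaning squared-error loss with unit noise variance and a Gaussian prior of hypothetical precision `ALPHA`;
- the sign convention $\mathcal{L}(\theta) - \lambda f(x^*;\theta)$ for the prediction-regularized loss;
- illustrative names throughout (`predict`, `grad_loss`, `fit`, `regvar_variance`, `laplace_variance`).

Both MAP fits use plain gradient descent, so no Hessian is inverted. The explicit $x^{*\top} P^{-1} x^*$ in `laplace_variance` serves only as a reference value. Because this loss is quadratic, the finite difference matches the reference for any $\lambda$.

```python
# Toy check of the RegVar principle on ridge regression (illustrative only).
# Assumptions, not taken from the paper: unit noise variance, a Gaussian prior
# with precision ALPHA, and the prediction-regularized loss L(theta) - lam * f(x*).

ALPHA = 1.0  # hypothetical prior precision

# Training inputs as (bias, x) feature pairs, with scalar targets.
X = [(1.0, -1.0), (1.0, 0.0), (1.0, 0.5), (1.0, 1.0)]
Y = [-0.9, 0.1, 0.4, 1.2]


def predict(theta, x):
    """Linear prediction f(x; theta) = theta . x."""
    return theta[0] * x[0] + theta[1] * x[1]


def grad_loss(theta, x_star=None, lam=0.0):
    """Gradient of 0.5*sum((f - y)^2) + 0.5*ALPHA*|theta|^2 - lam * f(x_star)."""
    g = [ALPHA * theta[0], ALPHA * theta[1]]
    for x, y in zip(X, Y):
        r = predict(theta, x) - y
        g[0] += r * x[0]
        g[1] += r * x[1]
    if x_star is not None:
        # The gradient of f(x_star; theta) with respect to theta is x_star itself.
        g[0] -= lam * x_star[0]
        g[1] -= lam * x_star[1]
    return g


def fit(x_star=None, lam=0.0, lr=0.1, steps=2000):
    """Plain gradient descent: no Hessian is formed or inverted."""
    theta = [0.0, 0.0]
    for _ in range(steps):
        g = grad_loss(theta, x_star, lam)
        theta = [theta[0] - lr * g[0], theta[1] - lr * g[1]]
    return theta


def regvar_variance(x_star, lam=1e-3):
    """Change in the prediction at x_star, divided by the regularization strength."""
    theta_map = fit()
    theta_lam = fit(x_star, lam)
    return (predict(theta_lam, x_star) - predict(theta_map, x_star)) / lam


def laplace_variance(x_star):
    """Reference x*^T P^{-1} x* with P = X^T X + ALPHA*I, via an explicit 2x2 inverse."""
    a = ALPHA + sum(x[0] * x[0] for x in X)
    b = sum(x[0] * x[1] for x in X)
    d = ALPHA + sum(x[1] * x[1] for x in X)
    u, v = x_star
    return (d * u * u - 2.0 * b * u * v + a * v * v) / (a * d - b * b)


for query in [(1.0, 0.25), (1.0, 5.0)]:  # inside vs. far outside the training inputs
    print(query, regvar_variance(query), laplace_variance(query))
```

For the in-range query `(1.0, 0.25)`, both functions return about 0.207. The query `(1.0, 5.0)` lies far outside the training inputs and gets a variance of about 7.7, because its prediction moves much more when the regularizer is added.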

Paper Structure

This paper contains 24 sections, 2 theorems, 26 equations, 9 figures, and 3 tables.

Key Result

Theorem 4.1

The derivative with respect to $\lambda$ of the prediction under the prediction-regularized MAP estimate recovers the variance-covariance term in Eq. eq:bdm, which is the linearized Laplace predictive variance. This holds assuming that $\hat{\theta}^{(f(x),\lambda)}$ and $P^{-1}$ exist and that ${\mathcal{L}_\theta}^{(f(x),\lambda)}$ is continuously differentiable with respect to $\theta$.
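The mechanism behind Theorem 4.1 can be seen from the first-order optimality condition. The following sketches the standard implicit-function argument and may differ from the paper's own proof. It makes three assumptions:

- the prediction-regularized loss has the form $\mathcal{L}(\theta) - \lambda f(x;\theta)$ (the sign convention is an assumption);
- $P = \nabla_\theta^2 \mathcal{L}(\hat\theta)$ is the Hessian at the unregularized MAP $\hat\theta$, so that $\hat\theta^{(f(x),0)} = \hat\theta$;
- the loss is smooth enough for the implicit function theorem to apply.

```latex
\begin{align*}
0 &= \nabla_\theta \mathcal{L}\big(\hat\theta^{(f(x),\lambda)}\big)
     - \lambda\, \nabla_\theta f\big(x;\hat\theta^{(f(x),\lambda)}\big)
  && \text{(first-order condition)} \\
0 &= P \left.\frac{d\hat\theta^{(f(x),\lambda)}}{d\lambda}\right|_{\lambda=0}
     - \nabla_\theta f(x;\hat\theta)
  && \text{(differentiate in } \lambda \text{, set } \lambda = 0\text{)} \\
\left.\frac{d\hat\theta^{(f(x),\lambda)}}{d\lambda}\right|_{\lambda=0}
  &= P^{-1}\,\nabla_\theta f(x;\hat\theta) \\
\left.\frac{d}{d\lambda} f\big(x;\hat\theta^{(f(x),\lambda)}\big)\right|_{\lambda=0}
  &= \nabla_\theta f(x;\hat\theta)^\top P^{-1}\, \nabla_\theta f(x;\hat\theta)
\end{align*}
```

The last line is the linearized Laplace predictive variance. For small $\lambda$, the finite difference $\big(f(x;\hat\theta^{(f(x),\lambda)}) - f(x;\hat\theta)\big)/\lambda$ approximates it with an $O(\lambda)$ error.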

Figures (9)

  • Figure 1: Summary of the regularization-based approach to uncertainty quantification that underlies RegVar. Step 1: fit a predictive model by MAP; step 2: fit another model using an infinitesimally regularized MAP; step 3: the difference between the predictions of these two models, normalized by the amount of regularization, is exactly the estimate of the predictive variance.
  • Figure 2: Example of fine-tuning a model using regularization variation (RegVar). The code needed to extend existing fine-tuning code to implement RegVar is shown in blue. Any automatic differentiation package suffices; here we use PyTorch (paszke2017automatic).
  • Figure 3: Calibration curves for the methods on the IMDB held-out evaluation with the XL GPT-3 network. Ideal calibration is plotted with dashed lines. 95% confidence intervals are shown with horizontal markers.
  • Figure 4: Contour plot of held-out negative log likelihood ($\downarrow$ is better) on the Wikitext validation set as a function of the $\lambda$ used to train RegVar and the variance rescaling used during inference. Lines isometric to the identity line are shown as gray dashes.
  • Figure 5: Fine-tuning and inference times for the methods on the IMDB dataset. Note: the ensemble could not be run on the larger model due to memory constraints.
  • ...and 4 more figures
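The three steps summarized in Figure 1 can be sketched on a model whose loss is not quadratic. In that case the normalized difference is exact only as $\lambda \to 0$. This is a hypothetical toy setup, not the paper's experiments. It assumes:

- Bayesian logistic regression with a Gaussian prior of precision `ALPHA`;
- the logit as the prediction $f$;
- the sign convention $\mathcal{L}(\theta) - \lambda f(x^*;\theta)$;
- gradient descent as the optimizer.

All function names are illustrative. `linearized_laplace` computes $\nabla_\theta f^\top P^{-1} \nabla_\theta f$ explicitly, purely for comparison.

```python
import math

# Hypothetical toy setup (not from the paper): Bayesian logistic regression
# with a Gaussian prior of precision ALPHA; the prediction f is the logit.
ALPHA = 1.0
X = [(1.0, -2.0), (1.0, -1.0), (1.0, 0.0), (1.0, 0.5), (1.0, 1.0), (1.0, 2.0)]
Y = [0, 0, 1, 0, 1, 1]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def logit(theta, x):
    return theta[0] * x[0] + theta[1] * x[1]


def grad(theta, x_star=None, lam=0.0):
    """Gradient of Bernoulli NLL + 0.5*ALPHA*|theta|^2 - lam * logit(theta, x_star)."""
    g = [ALPHA * theta[0], ALPHA * theta[1]]
    for x, y in zip(X, Y):
        r = sigmoid(logit(theta, x)) - y
        g[0] += r * x[0]
        g[1] += r * x[1]
    if x_star is not None:
        g[0] -= lam * x_star[0]
        g[1] -= lam * x_star[1]
    return g


def fit(x_star=None, lam=0.0, lr=0.2, steps=3000):
    """Gradient descent to the (prediction-regularized) MAP estimate."""
    theta = [0.0, 0.0]
    for _ in range(steps):
        g = grad(theta, x_star, lam)
        theta = [theta[0] - lr * g[0], theta[1] - lr * g[1]]
    return theta


def regvar(x_star, lam):
    theta_map = fit()             # step 1: MAP fit
    theta_lam = fit(x_star, lam)  # step 2: fit with the prediction regularizer
    # step 3: difference of predictions, normalized by the regularization strength
    return (logit(theta_lam, x_star) - logit(theta_map, x_star)) / lam


def linearized_laplace(x_star):
    """Reference grad_f^T P^{-1} grad_f, with P the posterior Hessian at the MAP."""
    theta = fit()
    a, b, d = ALPHA, 0.0, ALPHA
    for x in X:
        p = sigmoid(logit(theta, x))
        w = p * (1.0 - p)
        a += w * x[0] * x[0]
        b += w * x[0] * x[1]
        d += w * x[1] * x[1]
    u, v = x_star  # the gradient of the logit w.r.t. theta is x_star itself
    return (d * u * u - 2.0 * b * u * v + a * v * v) / (a * d - b * b)


x_star = (1.0, 3.0)
ref = linearized_laplace(x_star)
for lam in (1e-1, 1e-2, 1e-3):
    print(lam, abs(regvar(x_star, lam) - ref) / ref)  # relative error
```

In this setup the relative error shrinks roughly in proportion to $\lambda$. That is consistent with the normalized difference being a first-order approximation of the derivative in Theorem 4.1.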

Theorems & Definitions (6)

  • Theorem 4.1
  • proof
  • Theorem 4.2
  • proof
  • proof
  • proof