Few-sample Variational Inference of Bayesian Neural Networks with Arbitrary Nonlinearities

David J. Schodt

TL;DR

A simple yet effective approach for propagating statistical moments through arbitrary nonlinearities with only 3 deterministic samples is demonstrated, enabling few-sample variational inference of BNNs without restricting the set of network layers used.

Abstract

Bayesian Neural Networks (BNNs) extend traditional neural networks to provide uncertainties associated with their outputs. On the forward pass through a BNN, predictions (and their uncertainties) are made either by Monte Carlo sampling network weights from the learned posterior or by analytically propagating statistical moments through the network. Though flexible, Monte Carlo sampling is computationally expensive and can be infeasible or impractical under resource constraints or for large networks. While moment propagation can ameliorate the computational costs of BNN inference, it can be difficult or impossible for networks with arbitrary nonlinearities, thereby restricting the possible set of network layers permitted with such a scheme. In this work, we demonstrate a simple yet effective approach for propagating statistical moments through arbitrary nonlinearities with only 3 deterministic samples, enabling few-sample variational inference of BNNs without restricting the set of network layers used. Furthermore, we leverage this approach to demonstrate a novel nonlinear activation function that we use to inject physics-informed prior information into output nodes of a BNN.
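To make the idea of propagating moments with 3 deterministic samples concrete, the following is a minimal sketch for a single Gaussian-distributed scalar (or an array of independent scalars) passed through an elementwise nonlinearity. The abstract does not specify the sample locations or weights, so the choices here are assumptions: the points $\mu$ and $\mu \pm \sqrt{3}\sigma$ with weights $2/3$ and $1/6$ are the standard 3-point Gauss-Hermite (equivalently, scalar unscented-transform) rule, and the function name `propagate_moments_3pt` is hypothetical.

```python
import numpy as np


def propagate_moments_3pt(f, mean, std):
    """Approximate the mean and variance of f(x) for x ~ N(mean, std**2).

    Uses 3 deterministic samples (an assumed, standard choice, not
    necessarily the paper's): mean and mean +/- sqrt(3) * std, weighted
    2/3, 1/6, 1/6. `f` must act elementwise on NumPy arrays.
    """
    mean = np.asarray(mean, dtype=float)
    std = np.asarray(std, dtype=float)

    # Three deterministic sample locations, stacked along a new leading axis.
    offset = np.sqrt(3.0) * std
    points = np.stack([mean - offset, mean, mean + offset])

    # Weights sum to 1; reshape so they broadcast over the input's shape.
    weights = np.array([1.0 / 6.0, 2.0 / 3.0, 1.0 / 6.0])
    weights = weights.reshape((3,) + (1,) * mean.ndim)

    # Push the samples through the nonlinearity, then take weighted moments.
    y = f(points)
    out_mean = np.sum(weights * y, axis=0)
    out_var = np.sum(weights * (y - out_mean) ** 2, axis=0)
    return out_mean, out_var


if __name__ == "__main__":
    # For f(x) = x**2 with x ~ N(1, 2**2) the exact moments are 5 and 48.
    m, v = propagate_moments_3pt(lambda x: x ** 2, 1.0, 2.0)
    print(m, v)
```

Because the nonlinearity is only ever evaluated at the sample points, nothing about `f` needs to be known analytically, which is the property that lets this kind of scheme work with arbitrary layers. The cost is three evaluations per input regardless of how many Monte Carlo samples would otherwise be needed.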

Paper Structure (18 sections, 7 equations, 6 figures)

Figures (6)

  • Figure 1: UTVI outperforms MCVI and matches SMP. BNN predictions using (a) SMP, (b) MCVI with 3 samples, and (c) UTVI. All data points represent the average over 10 networks trained independently from distinct random seeds. Note that MCVI with 3 samples appears less noisy than expected because of the averaging over 10 networks (see Figure \ref{fig:mc_3samples} for an example output from a single network). The "true uncertainty" is defined as $\sqrt{\epsilon(x)^2 + \delta(x)^2}$, where $\epsilon(x)$ is the standard deviation of the simulated noise and $\delta(x)$ is the observed deviation of the prediction from the ground truth. (d) Reconstruction loss after each epoch (negative log-likelihood of in-distribution validation data given the trained model), averaged across the 10 networks.
  • Figure 2: UTVI is 10X faster than MCVI at a similar level of performance. (a) Minimum negative log-likelihood achieved by the best models on the evaluation set for MCVI models trained and evaluated with a varying number of samples. All data points shown were averaged over 10 independently trained models initialized with different random seeds. Error bars indicate $\pm$1 standard deviation over the 10 models. "Best" models were selected by taking the training checkpoint that achieved the lowest negative log-likelihood over the evaluation set. Results from UTVI with 3 samples and SMP (sampling-free) are shown for comparison. (b) Relative evaluation time averaged over 10 batches of 1024 network evaluations on an NVIDIA RTX A1000 laptop GPU. At $2^7$ samples (with which MCVI achieves roughly the same performance as UTVI), inference using UTVI is approximately 10X faster than with MCVI.
  • Figure 3: UTVI accurately predicts variance after nonlinear transformation. (a) CRB (Eqn. \ref{eqn:crb}) of emitter position estimates as a function of position across an $L \times L$ image. (b-c) Predictive variance of the position inferred by (b) UTVI with 3 sigma points and (c) MCVI with 10 samples for the network architecture described in Section \ref{sec:architecture}. Images are scaled independently such that color is proportional to variance.
  • Figure 4: UTVI outperforms MCVI and matches SMP. Example results from a single model (for each of SMP, MCVI, and UTVI) for comparison to the averages over 10 models shown in Figure \ref{fig:fc_bnn}. Note that MCVI predictions are noisier than is apparent from the model-averaged results.
  • Figure 5: Examples of simulated object localization data. Random examples of data generated by the simulator described in Section \ref{sec:gauss_sim}. Red dots are added for visualization to indicate the ground truth object position and are not included during network training or evaluation. Images were rescaled independently to improve visualization.
  • ...and 1 more figure
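The two evaluation quantities named in the captions above can be sketched in a few lines. The "true uncertainty" follows the definition given in the Figure 1 caption directly. The negative log-likelihood assumes a Gaussian predictive distribution with the network's predicted mean and variance, which is an assumption of this sketch, as the captions do not state the likelihood. Both function names are hypothetical.

```python
import numpy as np


def true_uncertainty(noise_std, prediction, ground_truth):
    """sqrt(eps(x)**2 + delta(x)**2), as defined in the Figure 1 caption.

    eps is the standard deviation of the simulated noise and delta is the
    deviation of the prediction from the ground truth.
    """
    delta = np.abs(np.asarray(prediction, dtype=float)
                   - np.asarray(ground_truth, dtype=float))
    return np.sqrt(np.asarray(noise_std, dtype=float) ** 2 + delta ** 2)


def gaussian_nll(target, pred_mean, pred_var):
    """Mean negative log-likelihood under an ASSUMED Gaussian predictive."""
    target = np.asarray(target, dtype=float)
    pred_mean = np.asarray(pred_mean, dtype=float)
    pred_var = np.asarray(pred_var, dtype=float)
    per_point = 0.5 * (np.log(2.0 * np.pi * pred_var)
                       + (target - pred_mean) ** 2 / pred_var)
    return np.mean(per_point)


if __name__ == "__main__":
    print(true_uncertainty(3.0, 1.0, 5.0))
    print(gaussian_nll([0.0], [0.0], [1.0]))
```

Under this reading, a well-calibrated model is one whose predicted standard deviation tracks the true uncertainty, while the negative log-likelihood penalizes both inaccurate means and over- or under-confident variances.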