BALI: Learning Neural Networks via Bayesian Layerwise Inference

Richard Kurle, Alexej Klushyn, Ralf Herbrich

TL;DR

A new method for learning Bayesian neural networks that treats them as a stack of multivariate Bayesian linear regression models, converges within a few iterations, and performs as well as or better than leading Bayesian neural network methods on various regression, classification, and out-of-distribution detection benchmarks.

Abstract

We introduce a new method for learning Bayesian neural networks that treats them as a stack of multivariate Bayesian linear regression models. The main idea is that the layerwise posterior can be inferred exactly if the target outputs of each layer are known. We define these pseudo-targets as the layer outputs from the forward pass, updated by the backpropagated gradients of the objective function. The resulting layerwise posterior is a matrix-normal distribution with a Kronecker-factorized covariance matrix, which can be inverted efficiently. Our method extends to the stochastic mini-batch setting via an exponential moving average over natural-parameter terms, thereby gradually forgetting older data. The method converges within a few iterations and performs as well as or better than leading Bayesian neural network methods on various regression, classification, and out-of-distribution detection benchmarks.
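The ingredients named in the abstract (a conjugate layerwise posterior, gradient-based pseudo-targets, and an exponential moving average of data terms) can be illustrated with a minimal NumPy sketch. It rests on our own simplifying assumptions, not on the paper's algorithm: isotropic noise with variance sigma^2, a prior whose column covariance matches the noise so that the conjugate posterior is Kronecker-factorized, and a squared-error loss on a single linear layer. The names `matrix_normal_posterior`, `sample_weights`, `pseudo_targets`, `EMAStats`, and the parameters `alpha`, `step_size`, and `decay` are hypothetical.

```python
import numpy as np


def matrix_normal_posterior(XtX, XtT, alpha):
    """Conjugate posterior for T = X @ W + noise with a matrix-normal prior.

    Assumption: noise ~ N(0, sigma^2) i.i.d. per entry and prior
    vec(W) ~ N(0, kron(sigma^2 * I_out, I_in / alpha)) (column-stacked vec).
    The posterior is then MN(M, U, sigma^2 * I_out), i.e.
    cov(vec(W)) = kron(sigma^2 * I_out, U), so only the d_in x d_in row
    factor U is inverted, never a (d_in * d_out)-sized matrix.
    """
    precision = XtX + alpha * np.eye(XtX.shape[0])  # U^{-1}
    M = np.linalg.solve(precision, XtT)             # posterior mean
    U = np.linalg.inv(precision)                    # row covariance factor
    return M, U


def sample_weights(M, U, sigma, rng):
    """Draw W ~ MN(M, U, sigma^2 I): W = M + sigma * chol(U) @ E."""
    L = np.linalg.cholesky(U)
    return M + sigma * L @ rng.standard_normal(M.shape)


def pseudo_targets(Z, grad_Z, step_size):
    """Forward-pass outputs moved against the loss gradient dL/dZ."""
    return Z - step_size * grad_Z


class EMAStats:
    """Exponential moving average of the data terms X^T X and X^T T.

    Hypothetical helper: each mini-batch term is rescaled by
    n_data / batch_size so it estimates a full-dataset sum, and older
    batches are forgotten geometrically at rate `decay`.
    """

    def __init__(self, d_in, d_out, n_data, decay=0.9):
        self.XtX = np.zeros((d_in, d_in))
        self.XtT = np.zeros((d_in, d_out))
        self.n_data, self.decay, self.initialized = n_data, decay, False

    def update(self, X, T):
        scale = self.n_data / X.shape[0]
        new_XtX, new_XtT = scale * X.T @ X, scale * X.T @ T
        if not self.initialized:  # first batch: no history to blend with
            self.XtX, self.XtT, self.initialized = new_XtX, new_XtT, True
            return
        d = self.decay
        self.XtX = d * self.XtX + (1 - d) * new_XtX
        self.XtT = d * self.XtT + (1 - d) * new_XtT


# Toy single-layer example with squared-error loss L = 0.5 * ||Z - Y||^2.
rng = np.random.default_rng(0)
N, d_in, d_out = 200, 3, 2
X = rng.normal(size=(N, d_in))
W_true = rng.normal(size=(d_in, d_out))
Y = X @ W_true + 0.1 * rng.normal(size=(N, d_out))

W = np.zeros((d_in, d_out))  # current posterior mean
stats = EMAStats(d_in, d_out, n_data=N, decay=0.9)
for _ in range(20):
    idx = rng.choice(N, size=50, replace=False)
    Xb, Yb = X[idx], Y[idx]
    Z = Xb @ W                 # forward pass
    grad_Z = Z - Yb            # backpropagated dL/dZ
    T = pseudo_targets(Z, grad_Z, step_size=1.0)
    stats.update(Xb, T)
    W, U = matrix_normal_posterior(stats.XtX, stats.XtT, alpha=1.0)

print("max |W - W_true|:", np.abs(W - W_true).max())
```

With `step_size=1.0` and a single linear layer, the pseudo-targets reduce to the observed targets, so the loop simply tracks a Bayesian linear regression posterior over mini-batches. In a deeper network, `Xb` would instead be a hidden layer's input activations, `Z` its outputs, and `grad_Z` the gradient supplied by backpropagation.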

Paper Structure

This paper contains 45 sections, 71 equations, 7 figures, 7 tables, and 5 algorithms.

Figures (7)

  • Figure 1: Posterior predictive distribution of BALI on synthetic regression and classification datasets (cf. the toy experiments section).
  • Figure 2: Negative log-likelihood (NLL) averaged over the test set (solid line) and the training set (dashed line), plotted against the number of iterations.
  • Figure 3: OOD detection ROC curves using the negative entropy of the predictive distribution to distinguish in-distribution (ID) from out-of-distribution (OOD) data. The model is trained on the ID dataset, and the respective other dataset is used as OOD data. True positives are ID data classified as ID; false positives are OOD data classified as ID.
  • Figure 4: Bayes By Backprop posterior predictive distribution on the sines-trend dataset (cf. the toy experiments section).
  • Figure 5: Bayes By Backprop posterior predictive distribution on the sinc dataset (cf. the toy experiments section).
  • ...and 2 more figures
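Figure 3 scores inputs by the negative entropy of the predictive distribution. A minimal NumPy sketch of such an evaluation might look as follows. It assumes the predictive distribution is approximated by averaging softmax outputs over S posterior weight samples, and it treats ID as the positive class, so a false positive is an OOD input scored as ID. The function names are hypothetical, and AUROC is computed as the Mann-Whitney probability that a random ID input outscores a random OOD input.

```python
import numpy as np


def predictive_probs(logits_samples):
    """Average softmax over posterior samples: (S, N, C) -> (N, C)."""
    z = logits_samples - logits_samples.max(axis=-1, keepdims=True)  # stability
    p = np.exp(z)
    p /= p.sum(axis=-1, keepdims=True)
    return p.mean(axis=0)


def negative_entropy(probs, eps=1e-12):
    """sum_c p_c log p_c; higher means a more confident (more ID-like) prediction."""
    return np.sum(probs * np.log(probs + eps), axis=-1)


def roc_curve(scores_id, scores_ood):
    """TPR/FPR when inputs with score >= threshold are classified as ID."""
    all_scores = np.concatenate([scores_id, scores_ood])
    thresholds = np.concatenate([[np.inf], np.sort(all_scores)[::-1]])
    tpr = np.array([(scores_id >= t).mean() for t in thresholds])   # ID kept as ID
    fpr = np.array([(scores_ood >= t).mean() for t in thresholds])  # OOD taken as ID
    return fpr, tpr


def auroc(scores_id, scores_ood):
    """P(random ID score > random OOD score), counting ties as 1/2."""
    diff = scores_id[:, None] - scores_ood[None, :]
    return np.mean(diff > 0) + 0.5 * np.mean(diff == 0)


# Toy usage: ID logits share a sharp per-input mean across samples,
# while OOD logits are small and vary randomly from sample to sample.
rng = np.random.default_rng(0)
S, C = 10, 5
logits_id = 4.0 * rng.normal(size=(1, 100, C)) + 0.5 * rng.normal(size=(S, 100, C))
logits_ood = 0.5 * rng.normal(size=(S, 100, C))
score_id = negative_entropy(predictive_probs(logits_id))
score_ood = negative_entropy(predictive_probs(logits_ood))
fpr, tpr = roc_curve(score_id, score_ood)
print("AUROC:", auroc(score_id, score_ood))
```

Averaging probabilities (rather than logits) across samples is what lets disagreement between posterior samples raise the predictive entropy, which is the signal this kind of OOD score relies on.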