Functional Variational Bayesian Neural Networks

Shengyang Sun, Guodong Zhang, Jiaxin Shi, Roger Grosse

TL;DR

This work reframes Bayesian neural networks as function-space models by placing priors directly over functions and optimizing a functional ELBO (fELBO). A key theoretical result shows that the KL divergence between stochastic processes equals the supremum of marginal KLs over finite input sets, which enables practical training via finite measurement sets and the Spectral Stein Gradient Estimator (SSGE) for implicit priors. The authors develop adversarial and sampling-based approaches to functional variational inference (fVI) and demonstrate that fBNNs with structured priors (including Gaussian processes and implicit processes) extrapolate well, provide reliable uncertainty estimates, and scale to large datasets. Empirically, fBNNs outperform weight-space VI baselines on regression benchmarks and perform well in contextual bandits and Bayesian optimization. Overall, the work offers a principled and scalable framework for incorporating rich functional priors into neural models, improving extrapolation and decision-making under uncertainty.

Abstract

Variational Bayesian neural networks (BNNs) perform variational inference over weights, but it is difficult to specify meaningful priors and approximate posteriors in a high-dimensional weight space. We introduce functional variational Bayesian neural networks (fBNNs), which maximize an Evidence Lower BOund (ELBO) defined directly on stochastic processes, i.e. distributions over functions. We prove that the KL divergence between stochastic processes equals the supremum of marginal KL divergences over all finite sets of inputs. Based on this, we introduce a practical training objective which approximates the functional ELBO using finite measurement sets and the spectral Stein gradient estimator. With fBNNs, we can specify priors entailing rich structures, including Gaussian processes and implicit stochastic processes. Empirically, we find fBNNs extrapolate well using various structured priors, provide reliable uncertainty estimates, and scale to large datasets.
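As a rough illustration of the marginal-KL idea in the abstract, the sketch below evaluates the KL divergence between two zero-mean Gaussian processes on growing finite measurement sets. It is not the authors' training code: the RBF kernels, the lengthscales, the input points, the jitter term, and every function name are assumptions chosen for illustration. Because GP marginals on a finite set are multivariate Gaussians, each marginal KL has a closed form, and adding inputs can only keep the value the same or increase it, which is consistent with the functional KL being a supremum over finite sets.

```python
import math

def rbf_kernel(xs, lengthscale, variance=1.0, jitter=1e-6):
    """Gram matrix of an RBF kernel on the points xs (jitter added for stability)."""
    n = len(xs)
    return [[variance * math.exp(-0.5 * ((xs[i] - xs[j]) / lengthscale) ** 2)
             + (jitter if i == j else 0.0)
             for j in range(n)] for i in range(n)]

def cholesky(a):
    """Lower-triangular L with L L^T = a, for a symmetric positive-definite a."""
    n = len(a)
    low = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(low[i][k] * low[j][k] for k in range(j))
            low[i][j] = math.sqrt(s) if i == j else s / low[j][j]
    return low

def forward_solve(low, b):
    """Solve low @ y = b for y by forward substitution."""
    y = []
    for i in range(len(b)):
        y.append((b[i] - sum(low[i][k] * y[k] for k in range(i))) / low[i][i])
    return y

def gaussian_kl(cov_p, cov_q):
    """KL[N(0, cov_p) || N(0, cov_q)] for zero-mean Gaussians."""
    n = len(cov_p)
    lp, lq = cholesky(cov_p), cholesky(cov_q)
    # trace(cov_q^{-1} cov_p) equals the squared Frobenius norm of lq^{-1} lp,
    # accumulated here one column of lp at a time.
    trace = 0.0
    for j in range(n):
        col = forward_solve(lq, [lp[i][j] for i in range(n)])
        trace += sum(v * v for v in col)
    logdet_p = 2.0 * sum(math.log(lp[i][i]) for i in range(n))
    logdet_q = 2.0 * sum(math.log(lq[i][i]) for i in range(n))
    return 0.5 * (trace - n + logdet_q - logdet_p)

def marginal_kl(xs):
    """Marginal KL at inputs xs between two hypothetical GPs (lengthscales 0.5 and 1.0)."""
    return gaussian_kl(rbf_kernel(xs, 0.5), rbf_kernel(xs, 1.0))

# Hypothetical measurement points; each prefix is a nested, larger measurement set.
points = [-1.0, 0.3, 1.2, 2.0]
kls = [marginal_kl(points[:m]) for m in range(1, len(points) + 1)]
for m, value in enumerate(kls, start=1):
    print(f"{m} measurement point(s): marginal KL = {value:.4f}")
```

With a single input the two marginals coincide, since both kernels have unit variance, so the first value is zero; the differences between the processes only show up once the measurement set contains several inputs. This is one way to see why the choice of measurement set matters in a finite approximation of the functional KL.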

Paper Structure

This paper contains 44 sections, 7 theorems, 39 equations, 5 figures, 6 tables, 1 algorithm.

Key Result

Theorem 1

For two stochastic processes $P$ and $Q$,
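
The statement is cut off at this point. Going by the abstract, which says that the KL divergence between stochastic processes equals the supremum of marginal KL divergences over all finite sets of inputs, it can be written roughly as below. The notation is an assumption made here: $\mathcal{X}$ is the input space, $\mathbf{X}$ is a finite set of $n$ inputs, and $P_{\mathbf{X}}$, $Q_{\mathbf{X}}$ are the marginals of $P$ and $Q$ at $\mathbf{X}$.

$$\mathrm{KL}[P \,\|\, Q] = \sup_{n \in \mathbb{N},\ \mathbf{X} \in \mathcal{X}^n} \mathrm{KL}[P_{\mathbf{X}} \,\|\, Q_{\mathbf{X}}]$$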

Figures (5)

  • Figure 1: Predictions on the toy function $y=x^3$. Here $a \times b$ represents $a$ hidden layers of $b$ units. Red dots are $20$ training points. The blue curve is the mean of the final predictions, and the shaded areas represent standard deviations. We compare fBNNs and Bayes-by-Backprop (BBB). For BBB, which performs weight-space inference, varying the network size leads to drastically different predictions. For fBNNs, which perform function-space inference, we observe consistent predictions for the larger networks. Note that the $1\times 100$ factorized Gaussian fBNN is not expressive enough to generate diverse predictions.
  • Figure 2: Extrapolating periodic structure. Red dots denote 20 training points. The green and blue lines represent the ground truth and the mean prediction, respectively. Shaded areas correspond to standard deviations. We considered GP priors with two kernels: RBF (which does not model the periodic structure), and $\text{PER}+\text{RBF}$ (which does). In each case, the fBNN makes similar predictions to the exact GP. In contrast, standard BBB (BBB-$1$) cannot even fit the training data, while BBB with the KL term scaled down by $0.001$ (BBB-$0.001$) fits the training data but fails to provide sensible extrapolations.
  • Figure 3: Implicit function priors and fBNN approximate posteriors. The leftmost column shows 3 prior samples. The other three columns show independent runs of the experiment. The red dots denote 40 training samples. We plot 4 posterior samples and show multiples of the predictive standard deviation as shaded areas.
  • Figure 4: Predictions on Mauna datasets. Red dots are training points. The blue line is the mean prediction and shaded areas correspond to standard deviations.
  • Figure 5: Bayesian Optimization. We plot the minimal value found along iterations. We compare fBNN, BBB and Random Feature methods for three kinds of functions corresponding to RBF, Order-$1$ ArcCosine and Matern12 GP kernels. We plot the mean and 0.2 standard deviations over 10 independent runs.

Theorems & Definitions (17)

  • Theorem 1
  • Theorem 2: Lower Bound
  • Corollary 3: Consistency under finite measurement points
  • Definition 1: KL divergence
  • Definition 2: Pushforward measure
  • Definition 3: Canonical projection map
  • Definition 4: Cylindrical $\sigma$-algebra
  • Theorem 4: Kolmogorov extension theorem (Øksendal, 2003)
  • Theorem 5
  • Definition 5: Replacing function
  • ...and 7 more