Amortising Inference and Meta-Learning Priors in Neural Networks

Tommy Rochussen, Vincent Fortuin

TL;DR

This work tackles the challenge of specifying priors in Bayesian deep learning by meta-learning a neural process whose latent variable is the set of BNN weights. The resulting Bayesian neural network process (BNNP) performs amortised, layerwise inference via pseudo-observations and is trained with a new posterior-predictive amortised variational inference (PP-AVI) objective that jointly targets accurate posteriors, faithful priors, and high-quality predictions. The BNNP also supports scalable within-task minibatching through sequential Bayesian updates and offers a mechanism for tuning prior flexibility, and attention-based extensions (AttnBNNP/BNAM) are introduced. A rigorous empirical evaluation shows that meaningful priors can exist and can improve performance, albeit not universally. The findings highlight the value of learning priors from diverse datasets and offer practical directions for data-efficient meta-learning of priors, while noting limitations in scalability and consistency that warrant further research.
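The layerwise inference via pseudo-observations mentioned above can be illustrated, for a single linear layer, with conjugate Bayesian linear regression. The sketch below is a minimal illustration under stated assumptions, not the authors' implementation: it assumes an isotropic Gaussian prior $\mathcal{N}(0, \sigma_p^2 I)$ on each output unit's weights, a Gaussian pseudo-likelihood with per-point pseudo-noise, and independence across output units. The function names `amortised_layer_posterior` and `sample_weights` are hypothetical, and the pseudo-outputs `z` and pseudo-noise `noise_var`, which an amortisation network would produce in the BNNP, are passed in directly.

```python
import numpy as np

def amortised_layer_posterior(h_c, z, noise_var, prior_var=1.0):
    """Gaussian posterior over one linear layer's weights given pseudo-observations.

    h_c:       (N, D_in)  layer inputs at the N context points
    z:         (N, D_out) pseudo-outputs (hypothetical: supplied directly here)
    noise_var: (N, D_out) positive pseudo-noise variances
    prior_var: variance of the assumed isotropic prior N(0, prior_var * I)

    Each output unit d is treated as an independent Bayesian linear regression:
        w_d ~ N(0, prior_var * I),  z[:, d] ~ N(h_c @ w_d, diag(noise_var[:, d]))
    Returns means (D_out, D_in) and covariances (D_out, D_in, D_in).
    """
    d_in = h_c.shape[1]
    d_out = z.shape[1]
    means = np.empty((d_out, d_in))
    covs = np.empty((d_out, d_in, d_in))
    for d in range(d_out):
        inv_noise = 1.0 / noise_var[:, d]                                   # (N,)
        # Posterior precision: prior precision + sum_n h_n h_n^T / noise_n
        precision = np.eye(d_in) / prior_var + h_c.T @ (h_c * inv_noise[:, None])
        covs[d] = np.linalg.inv(precision)
        # Posterior mean: cov @ sum_n h_n z_n / noise_n
        means[d] = covs[d] @ (h_c.T @ (inv_noise * z[:, d]))
    return means, covs

def sample_weights(means, covs, rng):
    """Draw one weight matrix of shape (D_out, D_in) from the factorised posterior."""
    return np.stack([rng.multivariate_normal(m, c) for m, c in zip(means, covs)])

# Toy usage: pseudo-outputs are noisy outputs of a known 3-input, 1-output linear map
rng = np.random.default_rng(0)
h_c = rng.normal(size=(20, 3))
true_w = np.array([[1.0, -2.0, 0.5]])
z = h_c @ true_w.T + 0.01 * rng.normal(size=(20, 1))
means, covs = amortised_layer_posterior(h_c, z, np.full((20, 1), 1e-4))
W = sample_weights(means, covs, rng)
```

Here the pseudo-outputs are simply noisy outputs of a known linear map, so the posterior mean recovers that map; in a trained BNNP the pseudo-observations would instead be whatever the learned amortisation network emits for the layer.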

Abstract

One of the core facets of Bayesianism is the updating of prior beliefs in light of new evidence, so how can we maintain a Bayesian approach if we have no prior beliefs in the first place? This is one of the central challenges in Bayesian deep learning, where it is not clear how to represent beliefs about a prediction task by prior distributions over model parameters. Bridging Bayesian deep learning and probabilistic meta-learning, we show how to $\textit{learn}$ a weights prior from a collection of datasets by performing per-dataset amortised variational inference. The model we develop can be viewed as a neural process whose latent variable is the set of weights of a BNN and whose decoder is the neural network parameterised by a sample of that latent variable. This model allows us to study the behaviour of Bayesian neural networks under well-specified priors, to use Bayesian neural networks as flexible generative models, and to perform feats in neural processes that are desirable but were previously elusive, such as within-task minibatching or meta-learning under extreme data starvation.
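Within-task minibatching via sequential Bayesian updates rests on a property of conjugate Gaussian models: conditioning on the data in chunks, with each posterior serving as the next prior, gives the same posterior as conditioning on all the data at once. The sketch below checks this identity for a single output unit of a linear layer, working in natural parameters (precision and precision times mean). It is an assumed simplification rather than the paper's procedure: it treats the layer inputs as fixed, whereas in a multi-layer BNNP they depend on sampled upstream weights. The helper `blr_update` and the $\mathcal{N}(0, I)$ prior are hypothetical choices for illustration.

```python
import numpy as np

def blr_update(prec, shift, h, z, noise_var):
    """Conjugate Gaussian update in natural parameters.

    prec:  (D, D) prior precision;  shift: (D,) prior precision @ prior mean
    h:     (B, D) layer inputs for one minibatch
    z:     (B,)   pseudo-outputs for one output unit;  noise_var: (B,) pseudo-noise
    Returns the posterior (precision, precision @ mean) after this minibatch.
    """
    inv_noise = 1.0 / noise_var
    return prec + h.T @ (h * inv_noise[:, None]), shift + h.T @ (inv_noise * z)

rng = np.random.default_rng(1)
N, D, B = 100, 4, 25
h = rng.normal(size=(N, D))
z = h @ rng.normal(size=D) + 0.1 * rng.normal(size=N)
noise_var = np.full(N, 0.01)
prior_prec, prior_shift = np.eye(D), np.zeros(D)     # assumed N(0, I) prior

# Condition on the whole context set at once
full_prec, full_shift = blr_update(prior_prec, prior_shift, h, z, noise_var)

# Condition minibatch by minibatch; each posterior becomes the next prior
prec, shift = prior_prec, prior_shift
for start in range(0, N, B):
    batch = slice(start, start + B)
    prec, shift = blr_update(prec, shift, h[batch], z[batch], noise_var[batch])

full_mean = np.linalg.solve(full_prec, full_shift)
seq_mean = np.linalg.solve(prec, shift)
```

In this sketch only the D×D precision and a D-vector are carried between minibatches, so memory does not grow with the number of context points processed.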

Paper Structure

This paper contains 62 sections, 1 theorem, 39 equations, 13 figures, 2 algorithms.

Key Result

Proposition 1

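For $|\Xi|\to\infty$, maximisation of $\mathcal{L}_\text{PP-AVI}(\Xi)$ directly targets the three desiderata of Definitions 1 to 3: accurate approximate posteriors, a faithful prior, and high-quality predictions.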

Figures (13)

  • Figure 1: Computational diagrams of (a) the amortised linear layer, and (b) a BNNP with one hidden layer of activations. We use the context $\cdot_c$ and target $\cdot_t$ notation to distinguish between inputs with labels, on which we condition, and inputs without labels, at which we predict.
  • Figure 2: ELBO and KL divergence between approximate and true posteriors for different VI methods. The BNNP closely approximates the true posterior.
  • Figure 3: Function samples from the true data-generating process (first column), learned BNNP prior predictive samples (second column), BNNP posterior predictive function samples (remaining columns).
  • Figure 4: Generative modelling of MNIST digits with the AttnBNNP. Smaller images depict sampled prior functions evaluated at the same 28$\times$28 grid of inputs on which the training data lie. Larger ones depict sampled functions queried at a 100$\times$100 grid of inputs. The AttnBNNP's prior has encoded the functional behaviour of handwritten digits, so super-resolution is natively supported without further training.
  • Figure 5: Demonstration of an ERA5 precipitation prediction test task with the BNNP. With no context points over Switzerland, the BNNP's predictive uncertainty is increased in that region.
  • ...and 8 more figures
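Figure 1(b) stacks amortised linear layers: each layer's weight posterior is conditioned on that layer's context activations, a weight sample is drawn, and both context and target activations are pushed through it. The sketch below mimics this structure for one hidden layer, as a hedged illustration only. The pseudo-observation maps are fixed random linear maps of the layer's context inputs concatenated with the context labels, standing in for the learned networks of the actual model; the fixed prior variance, ReLU activation, bias handling, and all function names (`layer_posterior`, `sample_layer`, `with_bias`, `bnnp_forward`) are assumptions.

```python
import numpy as np

def layer_posterior(h, z, noise_var, prior_var=1.0):
    """Gaussian posterior over a linear layer's weights, batched over output units.

    h: (N, D_in) layer inputs at context points; z, noise_var: (N, D_out)
    Returns mean (D_out, D_in) and covariance (D_out, D_in, D_in).
    """
    w = 1.0 / noise_var
    # precision[o] = I / prior_var + sum_n w[n, o] * outer(h[n], h[n])
    prec = np.eye(h.shape[1]) / prior_var + np.einsum("nd,no,ne->ode", h, w, h)
    cov = np.linalg.inv(prec)
    # mean[o] = cov[o] @ h.T @ (w[:, o] * z[:, o])
    mean = np.einsum("ode,ne,no->od", cov, h, w * z)
    return mean, cov

def sample_layer(mean, cov, rng):
    """Reparameterised sample W = mean + L eps, with L the Cholesky factor of cov."""
    chol = np.linalg.cholesky(cov)
    eps = rng.standard_normal(mean.shape)
    return mean + np.einsum("ode,oe->od", chol, eps)

def with_bias(h):
    """Append a constant column so each layer has a bias weight."""
    return np.concatenate([h, np.ones((h.shape[0], 1))], axis=1)

def bnnp_forward(x_c, y_c, x_t, params, rng):
    """One function sample at the targets x_t, conditioned on (x_c, y_c).

    params: list of (A, noise_var) per layer; A is a hypothetical fixed map from
    [layer inputs, y_c] to pseudo-outputs (a stand-in for a learned network).
    """
    h_c, h_t = with_bias(x_c), with_bias(x_t)
    for i, (A, noise_var) in enumerate(params):
        z = np.concatenate([h_c, y_c], axis=1) @ A          # pseudo-outputs (N, D_out)
        mean, cov = layer_posterior(h_c, z, np.full(z.shape, noise_var))
        W = sample_layer(mean, cov, rng)                    # (D_out, D_in)
        a_c, a_t = h_c @ W.T, h_t @ W.T
        if i < len(params) - 1:                             # hidden layer: ReLU + bias column
            h_c = with_bias(np.maximum(a_c, 0.0))
            h_t = with_bias(np.maximum(a_t, 0.0))
    return a_t

rng = np.random.default_rng(0)
x_c = rng.uniform(-2.0, 2.0, size=(10, 1))
y_c = np.sin(x_c)
x_t = np.linspace(-3.0, 3.0, 50)[:, None]
H = 16
params = [(rng.normal(size=(2 + 1, H)), 0.1),       # layer 1: [x, 1, y] -> H pseudo-outputs
          (rng.normal(size=(H + 1 + 1, 1)), 0.1)]   # layer 2: [a, 1, y] -> 1 pseudo-output
f_t = bnnp_forward(x_c, y_c, x_t, params, rng)
```

Calling `bnnp_forward` repeatedly with the same context gives different function samples, which is how posterior predictive samples like those in Figure 3 would be drawn in this sketch.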

Theorems & Definitions (5)

  • Proposition 1
  • Definition 1: accurate approximate posteriors
  • Definition 2: faithful prior
  • Definition 3: high-quality predictions
  • Proof of Proposition 1