
Bayesian Online Natural Gradient (BONG)

Matt Jones, Peter Chang, Kevin Murphy

TL;DR

The authors prove that BONG recovers exact Bayesian inference when the model is conjugate, and show empirically that it outperforms other online VB methods in the non-conjugate setting, such as online learning for neural networks, especially when controlling for computational cost.

Abstract

We propose a novel approach to sequential Bayesian inference based on variational Bayes (VB). The key insight is that, in the online setting, we do not need to add the KL term to regularize to the prior (which comes from the posterior at the previous timestep); instead we can optimize just the expected log-likelihood, performing a single step of natural gradient descent starting at the prior predictive. We prove this method recovers exact Bayesian inference if the model is conjugate. We also show how to compute an efficient deterministic approximation to the VB objective, as well as our simplified objective, when the variational distribution is Gaussian or a sub-family, including the case of a diagonal plus low-rank precision matrix. We show empirically that our method outperforms other online VB methods in the non-conjugate setting, such as online learning for neural networks, especially when controlling for computational costs.
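The abstract's central update can be made concrete in the simplest conjugate case. The following is a minimal sketch, assuming a full-covariance Gaussian variational family and a linear-Gaussian likelihood $y \sim \mathcal{N}(H\theta, R)$. This setting is our own illustrative choice, and the helper names `bong_step_linear_gaussian` and `exact_posterior` are hypothetical. By the standard Gaussian natural-gradient identity, one natural-gradient step with unit step size on $\mathbb{E}_q[\log p(y \mid \theta)]$, starting from the prior $q = \mathcal{N}(\mu, \Sigma)$, gives $\Sigma_{\text{new}}^{-1} = \Sigma^{-1} - \mathbb{E}_q[\nabla^2 \log p]$ and $\mu_{\text{new}} = \mu + \Sigma_{\text{new}}\, \mathbb{E}_q[\nabla \log p]$. Here both expectations are exact, so the result can be checked against the closed-form Kalman update.

```python
import numpy as np


def bong_step_linear_gaussian(mu, Sigma, H, R, y):
    """Hypothetical helper: one natural-gradient step (step size 1) on
    E_q[log p(y | theta)], starting at the prior q = N(mu, Sigma),
    for the linear-Gaussian likelihood y ~ N(H theta, R).
    Expectations are exact because the log-likelihood is quadratic in theta.
    """
    R_inv = np.linalg.inv(R)
    # Expected gradient of the log-likelihood under the prior: H^T R^{-1} (y - H mu)
    g = H.T @ R_inv @ (y - H @ mu)
    # Expected Hessian of the log-likelihood (constant in theta): -H^T R^{-1} H
    G = -H.T @ R_inv @ H
    # Natural-gradient step written in precision form
    prec_new = np.linalg.inv(Sigma) - G
    Sigma_new = np.linalg.inv(prec_new)
    mu_new = mu + Sigma_new @ g
    return mu_new, Sigma_new


def exact_posterior(mu, Sigma, H, R, y):
    """Hypothetical helper: exact Bayesian (Kalman) update, used as the reference."""
    S = H @ Sigma @ H.T + R             # innovation covariance
    K = Sigma @ H.T @ np.linalg.inv(S)  # Kalman gain
    return mu + K @ (y - H @ mu), Sigma - K @ H @ Sigma


rng = np.random.default_rng(0)
P, D = 4, 2  # parameter and observation dimensions
A = rng.normal(size=(P, P))
Sigma0 = A @ A.T + P * np.eye(P)  # well-conditioned prior covariance
mu0 = rng.normal(size=P)
H = rng.normal(size=(D, P))
R = 0.5 * np.eye(D)
y = rng.normal(size=D)

mu_b, Sigma_b = bong_step_linear_gaussian(mu0, Sigma0, H, R, y)
mu_e, Sigma_e = exact_posterior(mu0, Sigma0, H, R, y)
print(np.allclose(mu_b, mu_e), np.allclose(Sigma_b, Sigma_e))
```

Both comparisons should agree up to floating-point error, which illustrates the conjugate-case exactness described above. In non-conjugate models the expectations are no longer available in closed form and need an approximation, such as the Hessian approximations named in the figure captions (MC-Hess, EF-Hess, Lin-Hess). This sketch does not attempt that.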

Paper Structure

This paper contains 70 sections, 2 theorems, 155 equations, 13 figures, 3 tables, and 7 algorithms.

Key Result

Proposition 4.1

Let the observation distribution (likelihood) be an exponential family with natural parameter ${\bm{\theta}}_{t}$ (where $T_l({\bm{y}}_t)={\bm{y}}_t$ is the sufficient statistic for the likelihood and $A({\bm{\theta}}_t)$ is the log-partition function), and let the prior be the conjugate exponential family with sufficient statistic $T({\bm{\theta}}_{t}) = \left[{\bm{\theta}}_{t}; -A({\bm{\theta}}_{t})\right]$. Then

Figures (13)

  • Figure 1: Performance on MNIST using the Lin-MC posterior predictive, where the posterior is computed using BONG, BOG, BBB, and BLR, each with the 3 tractable Hessian approximations, on the dlr-10 variational family.
  • Figure 2: Performance on MNIST using the Lin-MC posterior predictive, where the posterior is computed using BONG with different variational families: diagonal (natural and moment), dlr-1, and dlr-10.
  • Figure 3: Runtimes for methods on MNIST. Left: corresponding to Figure 1, using different algorithms on the dlr-10 family. Right: corresponding to Figure 2, using BONG on different variational families.
  • Figure 4: Running time (seconds) vs. number of parameters $P$ (size of state space) on a synthetic regression problem. For BBB and BLR, we show results using $I=1$ and $I=10$ iterations per step. Hessian approximations are denoted as follows: EF0-Lin0 = MC-Hess, EF1-Lin0 = EF-Hess, EF0-Lin1 = Lin-Hess. (a) Full covariance representation. (b) DLR representation. The BLR curve is truncated because it ran out of memory.
  • Figure 5: MNIST results for methods using the DLR family. The left column shows the misclassification rate; the right column shows NLL. The first row uses the plugin approximation to the posterior predictive, the second row uses the linearized MC approximation, and the third row uses the standard MC approximation.
  • ...and 8 more figures
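Several of the figures above use the DLR (diagonal plus low-rank) representation that the abstract mentions for the precision matrix. The following is a minimal sketch of why this representation is cheaper than a full covariance. It assumes the precision has the form $\Lambda = \mathrm{diag}(d) + W W^\top$ with $W$ of shape $P \times L$. The names `d`, `W`, `L`, and `dlr_solve` are hypothetical, and reading the suffix in "dlr-10" as the rank $L$ is our assumption. Solving $\Lambda x = b$ with the Woodbury identity only requires an $L \times L$ solve, so it costs $O(PL^2)$ instead of the $O(P^3)$ of a dense solve.

```python
import numpy as np


def dlr_solve(d, W, b):
    """Hypothetical helper: solve (diag(d) + W W^T) x = b via the Woodbury identity.

    d: (P,) positive diagonal; W: (P, L) low-rank factor; b: (P,) right-hand side.
    Only an L x L linear system is solved, so the cost is O(P L^2).
    """
    Dinv_b = b / d                      # D^{-1} b
    Dinv_W = W / d[:, None]             # D^{-1} W, shape (P, L)
    rank = W.shape[1]
    inner = np.eye(rank) + W.T @ Dinv_W  # I + W^T D^{-1} W, shape (L, L)
    # Woodbury: (D + W W^T)^{-1} b = D^{-1} b - D^{-1} W inner^{-1} W^T D^{-1} b
    correction = Dinv_W @ np.linalg.solve(inner, W.T @ Dinv_b)
    return Dinv_b - correction


rng = np.random.default_rng(1)
P, L = 50, 10  # assumed: a "dlr-10" family corresponds to rank L = 10
d = rng.uniform(0.5, 2.0, size=P)
W = rng.normal(size=(P, L))
b = rng.normal(size=P)

x = dlr_solve(d, W, b)
Lam = np.diag(d) + W @ W.T  # dense precision, built here only to check the result
print(np.allclose(Lam @ x, b))
```

The dense matrix `Lam` appears only for verification. A DLR-based method would store just `d` and `W`, which is what keeps memory linear in $P$ for fixed rank.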

Theorems & Definitions (4)

  • Proposition 4.1
  • Proposition 4.2
  • Proof of Proposition 4.1
  • Proof of Proposition 4.2