
Rao-Blackwellised Reparameterisation Gradients

Kevin H. Lam, Thang D. Bui, George Deligiannidis, Yee Whye Teh

TL;DR

The paper addresses gradient estimation for models with latent Gaussian variables by introducing the R2-G2 estimator, a Rao-Blackwellised version of the reparameterisation gradient estimator with lower variance. It shows that the local reparameterisation trick is an instance of R2-G2, and it provides a practical least-squares formulation, solved with a conjugate gradient method, for computing the Rao-Blackwellised gradient efficiently. Empirically, the approach yields higher log-likelihoods and ELBOs for Bayesian neural networks and hierarchical VAEs, with consistent variance reduction in the gradient estimates, especially when the reparameterisation trick is applied multiple times. The method broadens the applicability of variance-reduced, single-sample gradient estimators to a wider range of probabilistic models, at the price of additional computational cost. Using R2-G2 during the initial phase of training offers a principled way to improve training of complex latent-variable models, which may also benefit pre-training and fine-tuning settings.
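The TL;DR does not spell out the least-squares system. One natural reading, which we assume here and which may differ from the paper's exact formulation, starts from Gaussian conditioning: if $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and $\mathbf{z} - \mathbf{W}\pmb{\mu} = \mathbf{B}\boldsymbol{\epsilon}$ with $\mathbf{B} = \mathbf{W}\mathbf{S}$, then $\mathbb{E}[\boldsymbol{\epsilon} \mid \mathbf{z}] = \mathbf{B}^\top(\mathbf{B}\mathbf{B}^\top)^{-1}(\mathbf{z} - \mathbf{W}\pmb{\mu})$, which is exactly the minimum-norm solution of $\mathbf{B}\boldsymbol{\epsilon} = \mathbf{z} - \mathbf{W}\pmb{\mu}$. The sketch below computes it with a hand-written conjugate gradient solver that only needs matrix-vector products with $\mathbf{B}\mathbf{B}^\top$. The function names `conjugate_gradient` and `conditional_mean_noise`, the scale matrix $\mathbf{S}$, and the random problem sizes are hypothetical.

```python
import numpy as np


def conjugate_gradient(matvec, b, tol=1e-10, max_iter=None):
    """Solve K y = b for a symmetric positive-definite K, given only v -> K @ v."""
    max_iter = max_iter or 10 * b.size
    y = np.zeros_like(b)
    r = b - matvec(y)  # current residual
    p = r.copy()       # current search direction
    rs = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs) < tol:
            break
        Kp = matvec(p)
        alpha = rs / (p @ Kp)
        y += alpha * p
        r -= alpha * Kp
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return y


def conditional_mean_noise(B, residual):
    """Minimum-norm solution of B @ eps = residual, i.e. B^T (B B^T)^{-1} residual.

    Assumes B has full row rank, so B B^T is positive definite. For
    eps ~ N(0, I) and residual = B @ eps, this equals E[eps | residual].
    """
    y = conjugate_gradient(lambda v: B @ (B.T @ v), residual)
    return B.T @ y


# Hypothetical example: z = W theta, theta = mu + S eps, so z - W mu = (W S) eps.
rng = np.random.default_rng(0)
W = rng.standard_normal((3, 8))
S = np.diag(rng.uniform(0.5, 1.5, size=8))
B = W @ S
eps = rng.standard_normal(8)
r = B @ eps  # plays the role of z - W mu

eps_hat = conditional_mean_noise(B, r)
eps_ref = np.linalg.lstsq(B, r, rcond=None)[0]  # NumPy's minimum-norm least-squares solution
print("CG matches lstsq:", np.allclose(eps_hat, eps_ref))
```

Only products with $\mathbf{B}$ and $\mathbf{B}^\top$ are required, so $(\mathbf{B}\mathbf{B}^\top)^{-1}$ is never formed explicitly; this is what makes an iterative solver attractive when $\mathbf{B}$ is large or structured.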

Abstract

Latent Gaussian variables have been popularised in probabilistic machine learning. In turn, gradient estimators are the machinery that facilitates gradient-based optimisation for models with latent Gaussian variables. The reparameterisation trick is often used as the default estimator as it is simple to implement and yields low-variance gradients for variational inference. In this work, we propose the R2-G2 estimator as the Rao-Blackwellisation of the reparameterisation gradient estimator. Interestingly, we show that the local reparameterisation gradient estimator for Bayesian MLPs is an instance of the R2-G2 estimator, and hence of Rao-Blackwellisation. This lets us extend the benefits of Rao-Blackwellised gradients to a suite of probabilistic models. We show that initial training with R2-G2 consistently yields better performance in models with multiple applications of the reparameterisation trick.
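The claim that the local reparameterisation trick (LRT) is a Rao-Blackwellisation of the reparameterisation trick (RT) can be checked numerically on a single linear unit. This is a minimal sketch, not the paper's general R2-G2 implementation, and it assumes a factorised Gaussian $\pmb{\theta} \sim \mathcal{N}(\pmb{\mu}, \mathrm{diag}(\pmb{\sigma}^2))$, a fixed input $\mathbf{x}$, $z = \mathbf{x}^\top\pmb{\theta}$, and the toy objective $f(z) = z^2$ (all hypothetical choices). The RT gradient with respect to $\sigma_i$ is $f'(z)\,x_i\epsilon_i$. The LRT samples $z = m + s\eta$ with $m = \mathbf{x}^\top\pmb{\mu}$ and $s^2 = \sum_j x_j^2\sigma_j^2$, giving $f'(z)\,\eta\,x_i^2\sigma_i/s$. Since $\mathbb{E}[\epsilon_i \mid z] = x_i\sigma_i\eta/s$, the LRT gradient equals $\mathbb{E}[\text{RT gradient} \mid z]$, so by the law of total variance its variance cannot be larger.

```python
import numpy as np


def rt_and_lrt_scale_grads(x, mu, sigma, f_prime, n_samples, seed=0):
    """Per-sample gradients of f(x @ theta) w.r.t. sigma, via RT and via LRT.

    theta ~ N(mu, diag(sigma**2)) is a hypothetical factorised Gaussian over
    the weights of one linear unit with fixed input x.
    """
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((n_samples, x.size))  # weight noise
    m = x @ mu                                      # mean of z
    s = np.sqrt(np.sum(x**2 * sigma**2))            # standard deviation of z
    z = m + eps @ (x * sigma)                       # z = x^T (mu + sigma * eps)
    df = f_prime(z)[:, None]                        # df/dz, shape (n, 1)

    # RT: differentiate through theta; dz/dsigma_i = x_i * eps_i.
    g_rt = df * x * eps

    # LRT: z = m + s * eta, so dz/dsigma_i = eta * x_i^2 * sigma_i / s.
    # Reusing eta = (z - m) / s evaluates E[g_rt | z] at the same z.
    eta = ((z - m) / s)[:, None]
    g_lrt = df * eta * x**2 * sigma / s
    return g_rt, g_lrt


x = np.array([1.0, -2.0, 0.5])
mu = np.array([0.3, 0.1, -0.4])
sigma = np.array([0.5, 1.0, 0.8])
g_rt, g_lrt = rt_and_lrt_scale_grads(
    x, mu, sigma, f_prime=lambda z: 2.0 * z, n_samples=200_000
)

# For f(z) = z^2, E[f] = m^2 + s^2, so d/dsigma_i E[f] = 2 * x_i^2 * sigma_i.
exact = 2.0 * x**2 * sigma
print("exact gradient:", exact)
print("RT  mean:", g_rt.mean(axis=0), " variance:", g_rt.var(axis=0))
print("LRT mean:", g_lrt.mean(axis=0), " variance:", g_lrt.var(axis=0))
```

Both estimators are unbiased for the exact gradient, while the LRT variance is smaller in every coordinate; the reduction is largest for weights that contribute little to $z$, since conditioning on $z$ then removes most of their noise.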

Paper Structure

This paper contains 32 sections, 3 theorems, 40 equations, 11 figures, 5 tables and 2 algorithms.

Key Result

Proposition 4.2

Denote $\mathbf{z} \sim q_{\mathbf{z}} = \mathcal{N}\left(\mathbf{W} \cdot \pmb{\mu}, \mathbf{A} \mathbf{A}^{\top} \right)$. Then we have and
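The extracted statement of Proposition 4.2 keeps only its setup, $\mathbf{z} \sim \mathcal{N}(\mathbf{W}\pmb{\mu}, \mathbf{A}\mathbf{A}^\top)$. Assuming (our reading, not confirmed by the extracted text) that $\mathbf{z} = \mathbf{W}\pmb{\theta}$ with $\pmb{\theta} \sim \mathcal{N}(\pmb{\mu}, \mathbf{S}\mathbf{S}^\top)$, any $\mathbf{A}$ with $\mathbf{A}\mathbf{A}^\top = \mathbf{W}\mathbf{S}\mathbf{S}^\top\mathbf{W}^\top$ describes the law of $\mathbf{z}$, for example the Cholesky factor of that matrix. The sketch below checks numerically that reparameterising $\pmb{\theta}$ and mapping through $\mathbf{W}$ gives the same mean and covariance as sampling $\mathbf{z}$ directly; the dimensions, seed and diagonal $\mathbf{S}$ are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
d_theta, d_z, n = 6, 2, 400_000

# Hypothetical problem: z = W theta with theta ~ N(mu, S S^T), S diagonal.
W = rng.standard_normal((d_z, d_theta)) / np.sqrt(d_theta)
mu = rng.standard_normal(d_theta)
S = np.diag(rng.uniform(0.2, 1.0, size=d_theta))

# Route 1 (reparameterise theta): theta = mu + S eps, then z = W theta.
eps = rng.standard_normal((n, d_theta))
z_via_theta = (mu + eps @ S.T) @ W.T

# Route 2 (reparameterise z directly): z = W mu + A eta, with A A^T = W S S^T W^T.
A = np.linalg.cholesky(W @ S @ S.T @ W.T)
eta = rng.standard_normal((n, d_z))
z_direct = W @ mu + eta @ A.T

print("W mu:         ", W @ mu)
print("mean (theta): ", z_via_theta.mean(axis=0))
print("mean (direct):", z_direct.mean(axis=0))
print("A A^T:\n", A @ A.T)
print("cov (theta):\n", np.cov(z_via_theta.T))
print("cov (direct):\n", np.cov(z_direct.T))
```

Route 2 draws $d_z$ noise variables per sample instead of $d_\theta$; sampling the lower-dimensional $\mathbf{z}$ directly is the same idea the local reparameterisation trick applies to pre-activations in Bayesian MLPs.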

Figures (11)

  • Figure 1: Log average gradient variance vs. epoch for the top layer of a Bayesian MLP trained on MNIST over 5 runs. We compare the variance of gradients when training using the reparameterisation (RT), local reparameterisation (LRT) and R2-G2 estimators.
  • Figure 2: Bounds on log-likelihood vs. optimisation steps for a three-layer VAE trained on MNIST over 5 runs. We compare the bounds on the log-likelihood when training using the reparameterisation (RT) and R2-G2 estimators. Training with the R2-G2 estimator improves bounds on the log-likelihood on both the training set (left) and test set (right).
  • Figure 3: Log gradient variance vs. epoch for the bottom layer of a Bayesian MLP trained on MNIST over 5 runs. We compare the variance of gradients when training using the reparameterisation (RT), local reparameterisation (LRT) and R2-G2 estimators.
  • Figure 4: Log gradient variance vs. epoch for the 8th convolutional layer of a Bayesian CNN trained on CIFAR-10 over 5 runs. We compare the variance of gradients when training using the reparameterisation (RT) and R2-G2 estimators.
  • Figure 5: Log gradient variance vs. epoch for the 5th convolutional layer of a Bayesian CNN trained on CIFAR-10 over 5 runs. We compare the variance of gradients when training using the reparameterisation (RT) and R2-G2 estimators.
  • ...and 6 more figures

Theorems & Definitions (4)

  • Definition 4.1: R2-G2
  • Proposition 4.2
  • Theorem 4.3
  • Lemma A.1 (Eaton, 1983)