Rao-Blackwellised Reparameterisation Gradients
Kevin H. Lam, Thang D. Bui, George Deligiannidis, Yee Whye Teh
TL;DR
The paper addresses gradient estimation for models with latent Gaussian variables by introducing the R2-G2 estimator, a Rao-Blackwellised version of the reparameterisation gradient estimator with lower variance. It establishes that the local reparameterisation trick is an instance of R2-G2, and it provides a practical least-squares formulation with a conjugate gradient solver for computing the Rao-Blackwellised gradient efficiently. Empirically, the approach yields higher log-likelihoods and ELBOs for Bayesian neural networks and hierarchical VAEs, with consistent variance reduction in the gradient estimates, especially in models that apply the reparameterisation trick multiple times. The method extends variance-reduced, single-sample gradient estimators to a wider range of probabilistic models, although this comes with a trade-off in computational cost. The work offers a principled, scalable way to initialise and accelerate training in complex latent-variable models, potentially improving performance in pre-training and fine-tuning settings.
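The TL;DR mentions a least-squares formulation solved with conjugate gradients. The sketch below does not reproduce the paper's formulation. It only illustrates a standard Gaussian-conditioning identity that such a solver can exploit, under our own assumptions: ε ~ N(0, I), a linear reparameterised quantity z = A ε + b with A of full row rank, and a hand-written CG routine (`conjugate_gradient` and `conditional_mean_eps` are hypothetical names). In this setting E[ε | z] = A^T (A A^T)^{-1} (z - b), which is the minimum-norm least-squares solution of A ε = z - b, so CG only needs matrix-vector products with A A^T. If a reparameterisation gradient is linear in ε, with coefficients that depend on ε only through z, then replacing ε by this conditional mean Rao-Blackwellises it.

```python
import numpy as np


def conjugate_gradient(matvec, rhs, tol=1e-10, max_iter=None):
    """Solve K y = rhs for symmetric positive-definite K, given only v -> K v."""
    max_iter = max_iter or 10 * rhs.shape[0]
    y = np.zeros_like(rhs)
    r = rhs - matvec(y)          # residual
    p = r.copy()                 # search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs_old) < tol:
            break
        Kp = matvec(p)
        alpha = rs_old / (p @ Kp)
        y = y + alpha * p
        r = r - alpha * Kp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return y


def conditional_mean_eps(A, b, z):
    """E[eps | z] for eps ~ N(0, I) and z = A @ eps + b (A of full row rank).

    Equals A^T (A A^T)^{-1} (z - b): the minimum-norm least-squares
    solution of A @ eps = z - b, computed with CG on K = A A^T.
    """
    y = conjugate_gradient(lambda v: A @ (A.T @ v), z - b)
    return A.T @ y


# Usage on a random toy problem (hypothetical sizes).
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 8))
b = rng.standard_normal(3)
eps = rng.standard_normal(8)
z = A @ eps + b
eps_bar = conditional_mean_eps(A, b, z)  # Rao-Blackwellised replacement for eps
```

CG is used here because it touches A A^T only through matrix-vector products, which keeps memory linear in the size of A; for a 3 x 8 toy matrix a direct solve would work equally well.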
Abstract
Latent Gaussian variables have been popularised in probabilistic machine learning. In turn, gradient estimators are the machinery that facilitates gradient-based optimisation for models with latent Gaussian variables. The reparameterisation trick is often used as the default estimator as it is simple to implement and yields low-variance gradients for variational inference. In this work, we propose the R2-G2 estimator as the Rao-Blackwellisation of the reparameterisation gradient estimator. Interestingly, we show that the local reparameterisation gradient estimator for Bayesian MLPs is an instance of the R2-G2 estimator and Rao-Blackwellisation. This lets us extend the benefits of Rao-Blackwellised gradients to a suite of probabilistic models. We show that initial training with R2-G2 consistently yields better performance in models with multiple applications of the reparameterisation trick.
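The abstract states that the local reparameterisation estimator for Bayesian MLPs is an instance of R2-G2. A minimal numerical sketch of that claim, under our own assumptions (not the paper's code): a single linear layer y = x W with factorised Gaussian weights W = M + S * E, E ~ N(0, I) elementwise, one input x, a toy squared-error loss, and gradients taken only with respect to the standard deviations S. Conditioning E on y gives E[E_ij | y] = x_i S_ij ξ_j / σ_j, where ξ_j = (y_j - μ_j) / σ_j. Substituting this into the weight-space reparameterisation gradient reproduces the local reparameterisation gradient sample by sample. All names below are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n = 10, 3, 100_000

# Toy layer: y = x @ W with W = M + S * E, E ~ N(0, I) elementwise.
x = rng.standard_normal(d_in)                          # one input vector
M = rng.standard_normal((d_in, d_out))                 # weight means
S = np.exp(rng.standard_normal((d_in, d_out)) - 1.0)   # weight std devs (> 0)
t = rng.standard_normal(d_out)                         # target of the toy loss

mu = x @ M                            # mean of each pre-activation y_j
sigma = np.sqrt((x ** 2) @ (S ** 2))  # std dev of each y_j


def loss_grad(y):
    """dL/dy for the toy loss L(y) = 0.5 * ||y - t||^2 (batched over rows)."""
    return y - t


def preactivation(E):
    """y = x @ (M + S * E) for a batch E of shape (n, d_in, d_out)."""
    return np.einsum("i,nij->nj", x, M + S * E)


def weight_space_grad_S(E):
    """Standard reparameterisation gradient: dL/dS_ij = x_i * g_j * E_ij."""
    g = loss_grad(preactivation(E))
    return x[None, :, None] * g[:, None, :] * E


def local_reparam_grad_S(xi):
    """Local reparameterisation y = mu + sigma * xi with xi ~ N(0, I).

    dL/dS_ij = g_j * xi_j * x_i^2 * S_ij / sigma_j (chain rule through sigma_j).
    """
    g = loss_grad(mu + sigma * xi)
    return (x ** 2)[None, :, None] * (g * xi / sigma)[:, None, :] * S


def rao_blackwellised_grad_S(E):
    """Weight-space gradient with E replaced by its conditional mean given y."""
    y = preactivation(E)
    xi = (y - mu) / sigma
    E_cond = x[None, :, None] * (xi / sigma)[:, None, :] * S  # E[E_ij | y_j]
    return x[None, :, None] * loss_grad(y)[:, None, :] * E_cond


E = rng.standard_normal((n, d_in, d_out))
xi = (preactivation(E) - mu) / sigma  # the same noise, expressed in y-space
g_ws = weight_space_grad_S(E)
g_lrt = local_reparam_grad_S(xi)
g_rb = rao_blackwellised_grad_S(E)

# Exact gradient: E[L] = 0.5 * sum_j ((mu_j - t_j)^2 + sigma_j^2), so dE[L]/dS_ij = x_i^2 S_ij.
exact = (x ** 2)[:, None] * S

print("max |RB - LRT| over samples:", np.abs(g_rb - g_lrt).max())
print("total variance, weight space:", g_ws.var(axis=0).sum())
print("total variance, local reparam:", g_lrt.var(axis=0).sum())
```

Because E_ij is independent of every y_k with k != j, conditioning on y reduces to conditioning on y_j, which is what gives the conditional mean its simple closed form. Both estimators are unbiased for `exact`, and by the Rao-Blackwell theorem the conditioned (local reparameterisation) estimator cannot have higher variance than the weight-space one.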
