
Beyond ReinMax: Low-Variance Gradient Estimators for Discrete Latent Variables

Daniel Wang, Thang D. Bui

TL;DR

The ReinMax-Rao and ReinMax-CV estimators, which incorporate Rao-Blackwellisation and control variate techniques into ReinMax to reduce its variance, are introduced; they demonstrate superior performance when training variational autoencoders with discrete latent spaces.
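Rao-Blackwellisation and control variates are standard Monte Carlo variance-reduction tools. As a generic illustration only (these are textbook examples, not the ReinMax-Rao or ReinMax-CV estimators), the sketch below estimates two simple expectations and compares the per-sample variance with and without each technique. The function names `control_variate_demo` and `rao_blackwell_demo` are hypothetical, introduced just for this sketch.

```python
import numpy as np


def control_variate_demo(n: int, seed: int = 0):
    """Estimate E[exp(U)] for U ~ Uniform(0, 1) with and without a control variate.

    The control variate is h(U) = U, whose mean (0.5) is known exactly, so
    subtracting c * (U - 0.5) leaves the expectation unchanged.
    """
    rng = np.random.default_rng(seed)
    u = rng.uniform(0.0, 1.0, size=n)
    f = np.exp(u)  # plain Monte Carlo samples
    # Variance-minimising coefficient c* = Cov(f, h) / Var(h), estimated from the
    # same samples (this adds a small bias that vanishes as n grows).
    c = np.cov(f, u)[0, 1] / np.var(u, ddof=1)
    f_cv = f - c * (u - 0.5)
    return f, f_cv


def rao_blackwell_demo(n: int, seed: int = 0):
    """Estimate E[(X + Y)^2] for independent X, Y ~ N(0, 1); the true value is 2.

    Rao-Blackwellisation replaces (X + Y)^2 by its conditional expectation
    E[(X + Y)^2 | X] = X^2 + 1, integrating Y out analytically.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    y = rng.standard_normal(n)
    naive = (x + y) ** 2
    rao_blackwellised = x ** 2 + 1.0
    return naive, rao_blackwellised


if __name__ == "__main__":
    f, f_cv = control_variate_demo(10_000)
    print(f"E[exp(U)] = {np.e - 1:.4f}")
    print(f"  plain MC:        mean {f.mean():.4f}, variance {f.var():.4f}")
    print(f"  control variate: mean {f_cv.mean():.4f}, variance {f_cv.var():.4f}")
    naive, rb = rao_blackwell_demo(10_000)
    print("E[(X + Y)^2] = 2")
    print(f"  plain MC:          mean {naive.mean():.4f}, variance {naive.var():.4f}")
    print(f"  Rao-Blackwellised: mean {rb.mean():.4f}, variance {rb.var():.4f}")
```

Both techniques keep the expectation (essentially) unchanged while shrinking the variance: the control variate subtracts a correlated quantity with known mean, and Rao-Blackwellisation integrates part of the randomness out exactly.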

Abstract

Machine learning models involving discrete latent variables require gradient estimators to facilitate backpropagation in a computationally efficient manner. The most recent addition to the Straight-Through family of estimators, ReinMax, can be viewed from a numerical ODE perspective as incorporating an approximation via Heun's method to reduce bias, but at the cost of high variance. In this work, we introduce the ReinMax-Rao and ReinMax-CV estimators, which incorporate Rao-Blackwellisation and control variate techniques into ReinMax to reduce its variance. Our estimators demonstrate superior performance on training variational autoencoders with discrete latent spaces. Furthermore, we investigate the possibility of leveraging alternative numerical methods for constructing more accurate gradient approximations, and we present an alternative view of ReinMax from a simpler numerical integration perspective.
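The numerical-integration view can be illustrated with a toy quadrature problem: approximating f(b) - f(a) from the derivative f' at the two endpoints. The sketch below assumes a simple β-weighted combination of the two endpoint slopes; this family and the names `endpoint_rule` and `sweep_beta` are illustrative assumptions, not the paper's definition of ReinMax or of the ReinMax-RK2 family in Figure 2. It shows that β = 0.5, the equal weighting used by Heun's method, is second-order accurate, whereas β = 0 (forward Euler) is only first-order.

```python
import numpy as np


def endpoint_rule(fprime, a: float, b: float, beta: float) -> float:
    """Approximate f(b) - f(a), i.e. the integral of f' over [a, b], from endpoint slopes.

    beta = 0 uses only the left slope (forward Euler), beta = 1 only the right slope,
    and beta = 0.5 averages them (the trapezoidal rule, which is what Heun's method
    reduces to for dy/dx = f'(x)).
    """
    return (b - a) * ((1.0 - beta) * fprime(a) + beta * fprime(b))


def sweep_beta(betas, f, fprime, a: float, b: float) -> np.ndarray:
    """Absolute error of endpoint_rule for each beta against the exact f(b) - f(a)."""
    exact = f(b) - f(a)
    return np.array([abs(endpoint_rule(fprime, a, b, beta) - exact) for beta in betas])


if __name__ == "__main__":
    # Hypothetical test function f = exp on a short interval (so f' = exp too).
    betas = np.linspace(-0.2, 1.2, 29)  # -0.2 to 1.2 in steps of 0.05
    errors = sweep_beta(betas, np.exp, np.exp, 0.0, 0.1)
    print(f"best beta on the grid: {betas[np.argmin(errors)]:.2f}")

    # Local error on one interval: O(h^2) for Euler, O(h^3) for the trapezoid,
    # so halving h cuts the Euler error by about 4x and the trapezoid error by about 8x.
    for h in (0.2, 0.1, 0.05):
        euler, heun = sweep_beta([0.0, 0.5], np.exp, np.exp, 0.0, h)
        print(f"h={h:<5} Euler error {euler:.2e}  trapezoid error {heun:.2e}")
```

In this toy setting the β that minimises the error tends to 0.5 as the interval shrinks (whenever f'' is non-zero), because only the symmetric weighting cancels the leading error term.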

Paper Structure

This paper contains 21 sections, 2 theorems, 30 equations, 2 figures and 4 tables.

Key Result

Theorem 1.

Figures (2)

  • Figure 1: The sample bias and variance of the estimators over checkpoints of a discrete VAE trained with ReinMax, with 8 categorical dimensions and 4 latent dimensions. The bias is measured by the cosine similarity between the exact gradient and the sample mean of 1024 gradient estimates, computed with a fixed batch of size 100 and fixed model parameters. The variance is the sample variance of 1024 gradient estimates. All temperatures are $\tau = 1$, except for ReinMax-CV, which uses $\tau = 0.1$.
  • Figure 2: The ELBO on the training set for the $8\times 4$, $8\times 16$ and $4\times 24$ VAEs after 50 epochs of training with $\widehat{\nabla}_{\text{ReinMax-RK2},\beta}$, as a function of $\beta$, which ranges from $-0.2$ to $1.2$ in increments of $0.05$. The minimum is achieved at approximately $\beta = 0.5$, which corresponds to the original ReinMax estimator. We use $\tau = 1$ here.
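The bias and variance measurements described in the Figure 1 caption can be sketched on a toy problem where the exact gradient is available by enumeration. This is not the paper's evaluation code: it assumes a single categorical variable with hypothetical logits `theta` and a hypothetical per-category objective `f`, uses score-function (REINFORCE) estimates with and without a constant baseline in place of the VAE estimators, and reports total variance as the sum of per-coordinate sample variances, which is one possible reading of "sample variance" for a gradient vector. As in the caption, 1024 estimates are drawn with the parameters held fixed.

```python
import numpy as np


def softmax(theta: np.ndarray) -> np.ndarray:
    z = np.exp(theta - theta.max())  # shift for numerical stability
    return z / z.sum()


def exact_grad(theta: np.ndarray, f: np.ndarray) -> np.ndarray:
    """Exact gradient of E_{z ~ Cat(softmax(theta))}[f[z]] w.r.t. theta, by enumeration."""
    pi = softmax(theta)
    return pi * (f - pi @ f)


def score_function_estimates(theta, f, n, baseline=0.0, seed=0):
    """n single-sample score-function (REINFORCE) estimates, one per row.

    Each estimate is (f[z] - baseline) * (onehot(z) - pi); any constant baseline
    keeps it unbiased because E[onehot(z) - pi] = 0.
    """
    rng = np.random.default_rng(seed)
    pi = softmax(theta)
    z = rng.choice(len(pi), size=n, p=pi)
    onehot = np.eye(len(pi))[z]
    return (f[z] - baseline)[:, None] * (onehot - pi)


def bias_and_variance(estimates: np.ndarray, true_grad: np.ndarray):
    """Cosine similarity between the sample-mean estimate and the exact gradient,
    and the total sample variance (summed over gradient coordinates)."""
    mean = estimates.mean(axis=0)
    cos = mean @ true_grad / (np.linalg.norm(mean) * np.linalg.norm(true_grad))
    variance = estimates.var(axis=0, ddof=1).sum()
    return cos, variance


if __name__ == "__main__":
    theta = np.array([0.1, -0.2, 0.3, 0.0])  # hypothetical logits
    f = np.array([10.0, 11.0, 12.0, 13.0])   # hypothetical per-category objective
    g = exact_grad(theta, f)
    plain = score_function_estimates(theta, f, 1024)
    with_baseline = score_function_estimates(theta, f, 1024, baseline=softmax(theta) @ f)
    for name, est in [("plain", plain), ("baseline", with_baseline)]:
        cos, var = bias_and_variance(est, g)
        print(f"{name:>8}: cosine similarity {cos:.3f}, total variance {var:.3f}")
```

A cosine similarity close to 1 means the averaged estimate points along the exact gradient. Both estimators here are unbiased, so any shortfall comes from Monte Carlo noise, which is far larger without the baseline.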

Theorems & Definitions (2)

  • Theorem 1.
  • Theorem 1