Discrete Variational Autoencoding via Policy Search
Michael Drolet, Firas Al-Hafez, Aditya Bhatt, Jan Peters, Oleg Arenz
TL;DR
The paper addresses learning discrete latent representations in VAEs, where exact reparameterization is unavailable. It introduces Discrete Autoencoding via Policy Search (DAPS), which casts encoder learning as KL-regularized policy search. DAPS derives a nonparametric target distribution without backpropagating through sampling. It then updates the parametric encoder by weighted maximum likelihood, with a trust region whose size is adjusted automatically via the effective sample size. DAPS uses a transformer-based autoregressive encoder and controls entropy explicitly through a beta regularizer. With these components, it outperforms approximate-reparameterization and quantization-based baselines in reconstruction quality on high-dimensional data such as ImageNet-256 and expressive motion data (LAFAN), while remaining scalable and robust across seeds. This approach provides a practical alternative to Gumbel-Softmax and VQ-VAE for discrete latent modeling, with potential benefits for downstream search and control tasks in robotics and beyond.
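The effective-sample-size (ESS) idea can be made concrete with a minimal NumPy sketch. This is not necessarily the authors' exact rule. It makes several assumptions. Each sample z_i drawn from the current encoder receives a self-normalized weight w_i ∝ exp(r_i / η), where r_i is some per-sample reward such as a reconstruction log-likelihood. The temperature η plays the role of the trust-region size. It is chosen by bisection so that the normalized ESS, 1 / (N Σ w_i²), hits a target fraction. The function names `normalized_weights`, `ess_fraction`, and `solve_temperature` and the target value 0.5 are hypothetical.

```python
import numpy as np

def normalized_weights(rewards: np.ndarray, eta: float) -> np.ndarray:
    """Self-normalized weights w_i proportional to exp(r_i / eta), computed stably."""
    logits = rewards / eta
    logits = logits - logits.max()  # subtract the max before exponentiating
    w = np.exp(logits)
    return w / w.sum()

def ess_fraction(w: np.ndarray) -> float:
    """Effective sample size 1 / sum(w_i^2), divided by the number of samples."""
    return float(1.0 / np.sum(w ** 2) / w.size)

def solve_temperature(rewards: np.ndarray, target: float = 0.5,
                      lo: float = 1e-6, hi: float = 1e6, iters: int = 100) -> float:
    """Bisect in log space for the eta whose weights reach the target ESS fraction.

    The ESS grows monotonically with eta: eta -> 0 puts all mass on the best
    sample, while eta -> inf gives uniform weights (ESS fraction 1).
    """
    log_lo, log_hi = np.log(lo), np.log(hi)
    for _ in range(iters):
        log_mid = 0.5 * (log_lo + log_hi)
        if ess_fraction(normalized_weights(rewards, np.exp(log_mid))) < target:
            log_lo = log_mid  # weights too concentrated: raise the temperature
        else:
            log_hi = log_mid  # weights flat enough: try a lower temperature
    return float(np.exp(log_hi))

# Usage with stand-in rewards (hypothetical per-sample values, e.g. log p(x | z_i)).
rng = np.random.default_rng(0)
rewards = rng.normal(size=256)
eta = solve_temperature(rewards, target=0.5)
w = normalized_weights(rewards, eta)
print(eta, ess_fraction(w))  # the ESS fraction should be close to 0.5
```

Bisection is a reasonable choice here because the ESS of such Boltzmann-style weights increases monotonically with η. A small η makes the update greedy: few samples carry the weight, which means a large step away from the current encoder. A large η keeps the weights nearly uniform, which means a small step.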
Abstract
Discrete latent bottlenecks in variational autoencoders (VAEs) offer high bit efficiency and can be modeled with autoregressive discrete distributions, enabling parameter-efficient multimodal search with transformers. However, discrete random variables do not admit an exact differentiable reparameterization. Discrete VAEs therefore typically rely on approximations, such as Gumbel-Softmax reparameterization or straight-through gradient estimates. Alternatively, they employ high-variance score-function gradient estimators such as REINFORCE, which have had limited success on high-dimensional tasks such as image reconstruction. Inspired by popular techniques in policy search, we propose a training framework for discrete VAEs that leverages the natural gradient of a non-parametric encoder to update the parametric encoder without requiring reparameterization. Our method combines this update with automatic step-size adaptation and a transformer-based encoder. It scales to challenging datasets such as ImageNet and outperforms both approximate reparameterization methods and quantization-based discrete autoencoders in reconstructing high-dimensional data from compact latent spaces.
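The idea of computing a non-parametric target and then projecting it onto a parametric encoder can be illustrated on a single categorical latent variable. This is a simplified sketch, not the authors' implementation. It assumes a softmax encoder over one latent instead of an autoregressive transformer, a hypothetical per-category reward r(z), and a fixed temperature η. The target q*(z) ∝ q_old(z) exp(r(z)/η) is the closed-form maximizer of E_q[r] − η KL(q ‖ q_old). Fitting a categorical distribution to samples from q_old by weighted maximum likelihood, with weights ∝ exp(r/η), recovers that target up to Monte Carlo error. No gradient passes through the sampling step. The names `exact_target` and `weighted_mle` are illustrative.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max()
    p = np.exp(z)
    return p / p.sum()

def exact_target(old_logits: np.ndarray, reward: np.ndarray, eta: float) -> np.ndarray:
    """Non-parametric target q*(z) proportional to q_old(z) * exp(r(z) / eta)."""
    return softmax(old_logits + reward / eta)

def weighted_mle(samples: np.ndarray, weights: np.ndarray, num_categories: int) -> np.ndarray:
    """Weighted maximum-likelihood categorical fit: weighted category frequencies."""
    counts = np.bincount(samples, weights=weights, minlength=num_categories)
    return counts / counts.sum()

rng = np.random.default_rng(0)
K, N, eta = 8, 400_000, 1.0
old_logits = rng.normal(size=K)      # current (parametric) encoder logits
reward = rng.normal(size=K)          # hypothetical per-category reward r(z)
q_old = softmax(old_logits)

z = rng.choice(K, size=N, p=q_old)   # sample latents from the current encoder
w = np.exp(reward[z] / eta)          # importance weights toward the target
w /= w.sum()

q_fit = weighted_mle(z, w, K)        # projection using samples only
q_star = exact_target(old_logits, reward, eta)
print(np.abs(q_fit - q_star).max())  # small Monte Carlo error
```

For a softmax-parameterized categorical, the target is simply `old_logits + reward / eta`, an exponentiated-gradient (mirror-descent) step. For an autoregressive encoder over many latent tokens, no such closed-form update is generally available. The weighted maximum-likelihood fit on samples is the practical counterpart in that setting.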
