Discrete Variational Autoencoding via Policy Search

Michael Drolet, Firas Al-Hafez, Aditya Bhatt, Jan Peters, Oleg Arenz

TL;DR

The paper addresses learning discrete latent representations in VAEs, where exact reparameterization is unavailable. It introduces Discrete Autoencoding via Policy Search (DAPS), which treats encoder learning as KL-regularized policy search. DAPS derives a nonparametric target distribution without backpropagating through sampling, then updates the parametric encoder by weighted maximum likelihood, with a trust region that is adjusted automatically via the effective sample size (ESS). Using a transformer-based autoregressive encoder and explicit entropy control through a $\beta$ regularizer, DAPS reconstructs high-dimensional data such as ImageNet-256 and expressive motion data (LAFAN) better than approximate reparameterization methods and quantization-based autoencoders, while remaining scalable and robust across seeds. The approach offers a practical alternative to Gumbel-Softmax and VQ-VAE for discrete latent modeling, with potential benefits for downstream search and control tasks in robotics and beyond.

Abstract

Discrete latent bottlenecks in variational autoencoders (VAEs) offer high bit efficiency and can be modeled with autoregressive discrete distributions, enabling parameter-efficient multimodal search with transformers. However, discrete random variables do not allow for exact differentiable parameterization; therefore, discrete VAEs typically rely on approximations, such as Gumbel-Softmax reparameterization or straight-through gradient estimates, or employ high-variance gradient-free methods such as REINFORCE that have had limited success on high-dimensional tasks such as image reconstruction. Inspired by popular techniques in policy search, we propose a training framework for discrete VAEs that leverages the natural gradient of a non-parametric encoder to update the parametric encoder without requiring reparameterization. Our method, combined with automatic step size adaptation and a transformer-based encoder, scales to challenging datasets such as ImageNet and outperforms both approximate reparameterization methods and quantization-based discrete autoencoders in reconstructing high-dimensional data from compact latent spaces.
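
The update described in the TL;DR (a nonparametric target, weighted maximum likelihood, and a trust region set through the effective sample size) can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. It assumes that K latent codes were sampled from the current encoder for one input, that each code has a scalar reward (for example a reconstruction log-likelihood), and that the target weights take the exponential form $w_i \propto \exp(r_i / \eta)$ that is common in KL-regularized policy search. The function names, the bisection bounds, and the `target_ess` parameter are hypothetical and chosen only for illustration.

```python
import numpy as np


def effective_sample_size(weights):
    """ESS = (sum w)^2 / sum w^2 for non-negative weights."""
    w = np.asarray(weights, dtype=float)
    return w.sum() ** 2 / np.square(w).sum()


def softmax_weights(rewards, eta):
    """Normalized weights proportional to exp(r_i / eta), computed stably."""
    r = np.asarray(rewards, dtype=float)
    w = np.exp((r - r.max()) / eta)  # subtract the max to avoid overflow
    return w / w.sum()


def solve_temperature(rewards, target_ess, lo=1e-6, hi=1e6, iters=100):
    """Bisection in log space for eta such that ESS(weights) ~= target_ess.

    ESS is non-decreasing in eta: a small eta concentrates the weight on the
    best sample (ESS near 1), a large eta gives uniform weights (ESS near K).
    """
    for _ in range(iters):
        mid = np.sqrt(lo * hi)  # geometric midpoint
        if effective_sample_size(softmax_weights(rewards, mid)) < target_ess:
            lo = mid  # weights too peaked, raise the temperature
        else:
            hi = mid
    return np.sqrt(lo * hi)


def weighted_nll(log_probs, weights):
    """Weighted maximum-likelihood loss: -sum_i w_i * log q_theta(z_i | x)."""
    return -np.sum(np.asarray(weights) * np.asarray(log_probs))


# Placeholder data: in practice the rewards and encoder log-probabilities
# would come from the decoder and the encoder, respectively.
rng = np.random.default_rng(0)
rewards = rng.normal(size=64)
log_q = np.log(np.full(64, 1.0 / 64))

eta = solve_temperature(rewards, target_ess=16.0)
w = softmax_weights(rewards, eta)
loss = weighted_nll(log_q, w)
print(f"eta={eta:.4f}  ESS={effective_sample_size(w):.2f}  loss={loss:.4f}")
```

In this sketch the ESS target plays the role of the step-size control: a lower target lets the weights concentrate on the best samples (a larger step), while a target close to K keeps the weights near uniform (a smaller step). In a real training loop the loss would be minimized with respect to the encoder parameters using an automatic-differentiation framework, with the weights treated as constants.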

Paper Structure

This paper contains 39 sections, 2 theorems, 33 equations, 32 figures, 2 tables.

Key Result

Lemma 1

Assume $\mathbb{E}_{q_\theta}[w^2] < \infty$, which holds in the case of finite discrete distributions with full support. Then, as $K\to\infty$,

Figures (32)

  • Figure 1: Overview of DAPS. Left: Images are split into patches and embedded by a feed-forward network before entering the encoder. Middle: The decoder generates latent sequences autoregressively: each step conditions on the encoder output and previously sampled latent codes (via causal masking, shown in gray). Right: The generative model decodes latent embeddings into an image.
  • Figure 2: Latent Code Utilization on MNIST Validation Dataset.
  • Figure 3: Validation reconstruction log-likelihoods, $\log p({\bm{x}} | {\bm{z}})$, throughout training ($\beta$ = 0.01).
  • Figure 4: ImageNet 256 validation reconstructions. Top to bottom: DAPS, FSQ, Groundtruth.
  • Figure 5: A latent sequence for hopping is decoded by DAPS and fed through inverse kinematics.
  • ...and 27 more figures

Theorems & Definitions (3)

  • Lemma 1: Population ESS and Rényi-2
  • proof
  • Corollary 1: KL trust region from ESS target
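
The limit in Lemma 1 is cut off in this listing, and Corollary 1 appears by title only. As orientation, and not as a restatement of the paper's results, the block below records the standard relation between the effective sample size and the Rényi divergence of order 2 that these titles point to. The notation is assumed rather than taken from the source: $q^*$ denotes the target distribution, $w = q^*/q_\theta$ the importance weight (so $\mathbb{E}_{q_\theta}[w] = 1$), $z_1, \dots, z_K$ are i.i.d. samples from $q_\theta$, and $\rho \in (0, 1]$ is a target ESS fraction.

```latex
% Sample ESS and its population limit (strong law of large numbers,
% which requires E_{q_\theta}[w^2] < \infty):
\widehat{\mathrm{ESS}}_K
  = \frac{\bigl(\sum_{i=1}^{K} w(z_i)\bigr)^2}{\sum_{i=1}^{K} w(z_i)^2},
\qquad
\frac{\widehat{\mathrm{ESS}}_K}{K}
  \xrightarrow[K\to\infty]{\text{a.s.}}
  \frac{\bigl(\mathbb{E}_{q_\theta}[w]\bigr)^2}{\mathbb{E}_{q_\theta}[w^2]}
  = \exp\bigl(-D_2(q^* \,\|\, q_\theta)\bigr),
\qquad
D_2(q^* \,\|\, q_\theta) = \log \mathbb{E}_{q_\theta}[w^2].

% Renyi divergences are non-decreasing in their order, so KL <= D_2.
% An ESS fraction of at least \rho therefore bounds the KL divergence:
\exp\bigl(-D_2(q^* \,\|\, q_\theta)\bigr) \ge \rho
\;\Longrightarrow\;
\mathrm{KL}(q^* \,\|\, q_\theta) \le D_2(q^* \,\|\, q_\theta) \le \log\frac{1}{\rho}.
```

Read this way, fixing an ESS target acts as an implicit KL trust region. For example, $\rho = 0.5$ would correspond to a KL bound of $\log 2 \approx 0.69$ nats under these assumptions.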