
Theory and Experiments on Vector Quantized Autoencoders

Aurko Roy, Ashish Vaswani, Arvind Neelakantan, Niki Parmar

TL;DR

This work reframes Vector Quantized VAE (VQ-VAE) training from an Expectation Maximization (EM) perspective. It shows that hard EM corresponds to a K-means-like assignment of encoder outputs to a code-book, with exponential moving average (EMA) code-book updates playing the role of the M-step. It then introduces soft EM, which uses Monte-Carlo sampling to update multiple latent codes per data point, leading to more stable training and better performance. Empirically, EM-trained VQ-VAE improves CIFAR-10 image generation. It also enables a non-autoregressive neural machine translation model that nearly matches a greedy autoregressive Transformer while being about $3.3\times$ faster at inference. Combining EM-inspired training with knowledge distillation yields a BLEU score of around $26.7$ on WMT'14 English–German, showing that principled training of discrete latent variables can reach competitive accuracy with significant inference-time gains.
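The hard-EM view can be made concrete with a small sketch. An E-step assigns each encoder output $z_e(x)$ to its nearest code-book vector. An M-step then updates each code-book vector from exponential moving averages (EMA) of its assignment count and of the sum of the encoder outputs assigned to it. This is a minimal pure-Python illustration, not the authors' implementation. The function names (`nearest_code`, `hard_em_step`), the `decay` and `eps` parameters, and the rule that a code is left unchanged when its running count is essentially zero are all assumptions made for this example.

```python
import math


def nearest_code(z, codebook):
    """Hard E-step: index of the code-book vector closest to z (squared Euclidean distance)."""
    best_k, best_dist = 0, math.inf
    for k, e in enumerate(codebook):
        dist = sum((zi - ei) ** 2 for zi, ei in zip(z, e))
        if dist < best_dist:
            best_k, best_dist = k, dist
    return best_k


def hard_em_step(encoder_outputs, codebook, counts, sums, decay=0.99, eps=1e-5):
    """One hard-EM (K-means-like) step with EMA code-book statistics (illustrative sketch).

    counts[k] and sums[k] are running averages of how many encoder outputs were
    assigned to code k and of their vector sum; code k becomes sums[k] / counts[k].
    """
    num_codes, dim = len(codebook), len(codebook[0])
    assignments = [nearest_code(z, codebook) for z in encoder_outputs]

    # Per-batch sufficient statistics for each code.
    batch_counts = [0.0] * num_codes
    batch_sums = [[0.0] * dim for _ in range(num_codes)]
    for z, k in zip(encoder_outputs, assignments):
        batch_counts[k] += 1.0
        for d in range(dim):
            batch_sums[k][d] += z[d]

    # M-step: exponential moving averages, then the running mean per code.
    new_counts = [decay * c + (1.0 - decay) * b for c, b in zip(counts, batch_counts)]
    new_sums = [[decay * s + (1.0 - decay) * b for s, b in zip(row, brow)]
                for row, brow in zip(sums, batch_sums)]
    # Hypothetical choice: codes whose running count is (near) zero keep their previous value.
    new_codebook = [[s / c for s in row] if c > eps else list(e)
                    for row, c, e in zip(new_sums, new_counts, codebook)]
    return assignments, new_codebook, new_counts, new_sums
```

With `decay=0.0` the step reduces to one plain K-means iteration: each used code moves to the mean of its assigned points. Values close to 1 make the code-book drift slowly across batches, which is the usual motivation for EMA updates.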

Abstract

Deep neural networks with discrete latent variables offer the promise of better symbolic reasoning and of learning abstractions that are more useful for new tasks. There has been a surge of interest in discrete latent variable models. However, despite several recent improvements, training them has remained challenging, and their performance has mostly failed to match that of their continuous counterparts. Recent work on vector quantized autoencoders (VQ-VAE) has made substantial progress in this direction, with its perplexity almost matching that of a VAE on datasets such as CIFAR-10. In this work, we investigate an alternative training technique for VQ-VAE, inspired by its connection to the Expectation Maximization (EM) algorithm. Training the discrete bottleneck with EM helps us achieve better image generation results on CIFAR-10. Together with knowledge distillation, it also allows us to develop a non-autoregressive machine translation model whose accuracy almost matches a strong greedy autoregressive baseline Transformer, while being 3.3 times faster at inference.

Paper Structure

This paper contains 20 sections, 14 equations, 4 figures, 5 tables.

Figures (4)

  • Figure 1: VQ-VAE model as described in vqvae. We use $x$ to denote the input image. The output of the encoder, $z_e(x) \in \mathbb{R}^D$, is used in a nearest-neighbor search to select the (sequence of) discrete latent variables. The selected discrete latent is used to train the latent predictor model, while the embedding $z_q(x)$ of the selected discrete latent is passed as input to the decoder.
  • Figure 2: VQ-VAE model adapted to conditional supervised translation as described in kaiser2018fast. We use $x$ and $y$ to denote the source and target sentence respectively. The encoder, the decoder and the latent predictor now additionally condition on the source sentence $x$.
  • Figure 3: Comparison of hard EM (green curve) vs. soft EM with different numbers of samples (yellow and blue curves) on the WMT'14 English-German translation dataset with a code-book size of $2^{14}$. As in kaiser2018fast, the encoder of the discrete autoencoder attends to the output of the source-sentence encoder. The $y$-axis denotes the teacher-forced BLEU score on the test set. Notice that the hard EM/$K$-means run collapsed, while the soft EM runs are more stable.
  • Figure 4: Samples of original and reconstructed images from CIFAR-10 using VQ-VAE trained using EM with a code-book of size $2^{8}$.
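The soft EM runs compared in Figure 3 differ from hard EM mainly in the E-step. This sketch assumes a posterior over codes proportional to $\exp(-\|z_e(x) - e_k\|^2)$. For each encoder output it draws `num_samples` codes by Monte-Carlo sampling, and each sample contributes weight `1/num_samples` to the EMA statistics. It is an illustrative pure-Python sketch, not the authors' code. The function name `soft_em_step`, its parameters, the seeded `random.Random` generator, and the rule of leaving unused codes unchanged are hypothetical choices for this example.

```python
import math
import random


def soft_em_step(encoder_outputs, codebook, counts, sums,
                 num_samples=10, decay=0.99, eps=1e-5, rng=None):
    """One soft-EM step: Monte-Carlo E-step followed by an EMA M-step (illustrative sketch)."""
    rng = rng if rng is not None else random.Random(0)
    num_codes, dim = len(codebook), len(codebook[0])
    batch_counts = [0.0] * num_codes
    batch_sums = [[0.0] * dim for _ in range(num_codes)]

    for z in encoder_outputs:
        # E-step: unnormalized log-posterior -||z - e_k||^2 for every code k.
        logits = [-sum((zi - ei) ** 2 for zi, ei in zip(z, e)) for e in codebook]
        top = max(logits)
        weights = [math.exp(logit - top) for logit in logits]  # shifted for numerical stability
        # Monte-Carlo approximation: sample codes instead of summing over all of them.
        for k in rng.choices(range(num_codes), weights=weights, k=num_samples):
            batch_counts[k] += 1.0 / num_samples
            for d in range(dim):
                batch_sums[k][d] += z[d] / num_samples

    # M-step: exponential moving averages of the soft counts and soft sums.
    new_counts = [decay * c + (1.0 - decay) * b for c, b in zip(counts, batch_counts)]
    new_sums = [[decay * s + (1.0 - decay) * b for s, b in zip(row, brow)]
                for row, brow in zip(sums, batch_sums)]
    # Each code becomes the running mean of its (soft) members; unused codes are left as-is.
    new_codebook = [[s / c for s in row] if c > eps else list(e)
                    for row, c, e in zip(new_sums, new_counts, codebook)]
    return new_codebook, new_counts, new_sums
```

As `num_samples` grows, the sampled counts approach the exact expectation under the assumed posterior. A point that lies between two codes then spreads its mass over both instead of committing to one, which is one plausible reading of why the soft runs in Figure 3 avoid the collapse seen with hard assignments.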