Theory and Experiments on Vector Quantized Autoencoders
Aurko Roy, Ashish Vaswani, Arvind Neelakantan, Niki Parmar
TL;DR
This work reframes Vector Quantized VAE (VQ-VAE) training from an Expectation Maximization (EM) perspective. It shows that hard EM corresponds to a K-means-like assignment of encoder outputs to a code-book, with exponential moving average (EMA) updates of the code-book playing the role of the M-step. It then introduces a soft EM variant that uses Monte-Carlo sampling to assign each data point to several latent codes rather than just one, which leads to more stable training and better performance. Empirically, EM-trained VQ-VAE improves CIFAR-10 image generation and enables a non-autoregressive neural machine translation model that closely matches a greedy autoregressive Transformer while being about $3.3\times$ faster at inference. Combined with knowledge distillation, EM-inspired training yields a BLEU score of about $26.7$ on WMT'14 English–German, showing that principled training of discrete latent variables can reach competitive accuracy with significant inference-time gains.
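The hard-EM reading summarized above can be sketched in NumPy as a nearest-code E-step followed by an EMA M-step, and the soft-EM variant as the same M-step fed by m sampled codes per data point. This is a minimal sketch under stated assumptions, not the authors' implementation: the function names, the `decay` and `eps` values, the sample count `m`, and the use of a softmax over negative squared distances as the soft assignment distribution are all illustrative choices.

```python
import numpy as np

# Sketch only: function names, decay, eps, m and the softmax posterior below
# are illustrative assumptions, not the paper's reference implementation.


def _sq_dists(z_e, codebook):
    # Squared L2 distance between every encoder output and every code: (N, K).
    return ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)


def _ema_m_step(weights, z_e, counts, sums, decay, eps):
    # M-step: EMA of each code's assignment mass and of the sum of the vectors
    # assigned to it; each code then becomes the (smoothed) mean of its points.
    counts = decay * counts + (1.0 - decay) * weights.sum(axis=0)
    sums = decay * sums + (1.0 - decay) * (weights.T @ z_e)
    codebook = sums / (counts[:, None] + eps)
    return codebook, counts, sums


def hard_em_step(z_e, codebook, counts, sums, decay=0.999, eps=1e-5):
    # Hard E-step: each encoder output goes to its nearest code, as in K-means.
    assign = _sq_dists(z_e, codebook).argmin(axis=1)
    weights = np.eye(codebook.shape[0])[assign]  # one-hot rows, shape (N, K)
    codebook, counts, sums = _ema_m_step(weights, z_e, counts, sums, decay, eps)
    return codebook, counts, sums, assign


def soft_em_step(z_e, codebook, counts, sums, m=10, decay=0.999, eps=1e-5, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    # Assumed assignment distribution: softmax over negative squared distances.
    logits = -_sq_dists(z_e, codebook)
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    # Monte-Carlo E-step: draw m codes per data point; each draw has weight 1/m.
    weights = np.stack([rng.multinomial(m, p) for p in probs]) / m
    codebook, counts, sums = _ema_m_step(weights, z_e, counts, sums, decay, eps)
    return codebook, counts, sums, weights


# Tiny demo: two well-separated clusters, EMA state initialised from the codebook.
z_e = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
codebook = np.array([[0.0, 0.0], [10.0, 10.0]])
new_cb, _, _, assign = hard_em_step(z_e, codebook, np.ones(2), codebook.copy(), decay=0.0)
print(assign)  # → [0 0 1 1]
```

With `decay=0.0` the hard step reduces to a single K-means update (each code moves to the mean of its assigned points); a decay close to 1 instead lets the code-book drift slowly across mini-batches. When the codes are far apart, the sampled soft assignments collapse to the hard ones, so the two steps agree.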
Abstract
Deep neural networks with discrete latent variables offer the promise of better symbolic reasoning and of learning abstractions that are more useful for new tasks. There has been a surge of interest in discrete latent variable models; however, despite several recent improvements, training them has remained challenging, and their performance has mostly failed to match that of their continuous counterparts. Recent work on vector quantized autoencoders (VQ-VAE) has made substantial progress in this direction, with its perplexity almost matching that of a VAE on datasets such as CIFAR-10. In this work, we investigate an alternative training technique for VQ-VAE, inspired by its connection to the Expectation Maximization (EM) algorithm. Training the discrete bottleneck with EM helps us achieve better image generation results on CIFAR-10, and, together with knowledge distillation, allows us to develop a non-autoregressive machine translation model whose accuracy almost matches a strong greedy autoregressive baseline Transformer, while being 3.3 times faster at inference.
