Learning Energy-based Variational Latent Prior for VAEs
Debottam Dutta, Chaitanya Amballa, Zhongweiyang Xu, Yu-Lin Wei, Romit Roy Choudhury
TL;DR
This work tackles the prior hole problem in VAEs by introducing an energy-based latent prior that can flexibly match the aggregate posterior while preserving fast sampling. EVaLP trains the EBM prior through the variational form of its log-normalizing constant, which is approximated with an amortized sampler (g_\alpha). This makes both training and generation MCMC-free and casts training as a stable alternating optimization, akin to a WGAN with a 1-Lipschitz energy function. The framework extends to hierarchical VAEs. On datasets such as MNIST, CelebA64, and CIFAR-10 it improves generation quality (lower FID), reduces prior-hole severity (lower MMD), and samples faster, with optional sampling importance resampling (SIR) for further gains. Overall, EVaLP is a practical, scalable way to combine the expressive power of EBMs with efficient latent-space sampling, yielding robust priors for high-quality generative models.
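As a sanity check on the variational form of the log-normalizing constant, the sketch below assumes the standard Gibbs variational principle, log Z = max_q { E_q[-E(z)] + H(q) }, with equality when q equals the normalized EBM; the paper may use a variant of this form (for example, one written relative to a base prior). The energy E(z) = z^2 / (2 sigma^2), the value SIGMA = 1.5, and the Gaussian family q = N(m, s^2) standing in for the sampler g_\alpha are all hypothetical choices made so that every term has a closed form. Every candidate q gives a lower bound on log Z, and the best one recovers it exactly.

```python
import math

# Hypothetical toy energy E(z) = z^2 / (2 * SIGMA^2); its normalizer is known exactly.
SIGMA = 1.5


def log_z_exact(sigma):
    """log Z for Z = integral of exp(-z^2 / (2 sigma^2)) dz = sqrt(2 pi) * sigma."""
    return 0.5 * math.log(2.0 * math.pi * sigma ** 2)


def gibbs_objective(m, s, sigma):
    """E_q[-E(z)] + H(q) for q = N(m, s^2), with both terms in closed form."""
    expected_neg_energy = -(m ** 2 + s ** 2) / (2.0 * sigma ** 2)  # E_q[z^2] = m^2 + s^2
    entropy = 0.5 * math.log(2.0 * math.pi * math.e * s ** 2)      # Gaussian entropy
    return expected_neg_energy + entropy


def best_gaussian(sigma):
    """Grid search over (m, s): a crude stand-in for training a sampler."""
    best = (-math.inf, None, None)
    for j in range(-10, 11):
        m = j / 10
        for k in range(1, 61):
            s = k / 20
            value = gibbs_objective(m, s, sigma)
            if value > best[0]:
                best = (value, m, s)
    return best


log_z = log_z_exact(SIGMA)
best_val, best_m, best_s = best_gaussian(SIGMA)
print(f"exact log Z = {log_z:.6f}")
print(f"best bound  = {best_val:.6f} at m = {best_m}, s = {best_s}")
```

In the paper's setting the maximization over q is approximated by a sampler network rather than a grid. With an implicit sampler the entropy term H(q) no longer has a closed form, a difficulty this toy sidesteps by using an explicit Gaussian.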
Abstract
Variational Auto-Encoders (VAEs) are known to generate blurry and inconsistent samples. One reason for this is the "prior hole" problem. A prior hole is a region of latent space that has high probability under the VAE's prior but low probability under the VAE's aggregate posterior. During data generation, high-probability samples from the prior can therefore land in regions the decoder rarely saw during training, resulting in poor-quality data. Ideally, a prior should be flexible enough to match the posterior while still allowing fast sample generation, and generative models continue to grapple with this tradeoff. This paper proposes to model the prior as an energy-based model (EBM). While EBMs are known to offer the flexibility to match posteriors (and to improve the ELBO), they are traditionally slow at sample generation because they depend on MCMC methods. Our key idea is to handle the normalization constant of the EBM with a variational approach, thus bypassing expensive MCMC. The variational form can be approximated with a sampler network, and we show that training the prior in this way can be formulated as an alternating optimization problem. Moreover, the same sampler reduces to an implicit variational prior during generation, providing efficient and fast sampling. We compare our Energy-based Variational Latent Prior (EVaLP) method to multiple SOTA baselines and show improvements in image generation quality, reduced prior holes, and better sampling efficiency.
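The alternating optimization, together with the TL;DR's WGAN analogy, can be illustrated with a deliberately tiny 1-D toy. Everything in it is a hypothetical stand-in chosen for clarity, not the paper's architecture. Gaussian draws play the role of aggregate-posterior latents, the sampler is g_\alpha(eps) = alpha + eps, and the energy is linear, E(z) = theta * z, so the 1-Lipschitz constraint becomes |theta| <= 1. The sampler's entropy term is also dropped. Each round first picks the energy that best separates sampler outputs (high energy) from posterior latents (low energy), then moves the sampler to lower its expected energy.

```python
import random

random.seed(0)

# Hypothetical stand-in for aggregate-posterior latents q(z): 1-D, centred near 2.0.
posterior_z = [random.gauss(2.0, 0.5) for _ in range(512)]
# Fixed base noise eps for the hypothetical sampler g_alpha(eps) = alpha + eps.
eps = [random.gauss(0.0, 0.5) for _ in range(512)]


def mean(xs):
    return sum(xs) / len(xs)


def energy_step(alpha):
    """Best 1-Lipschitz linear energy E(z) = theta * z against the current sampler.

    Maximizes E_g[E] - E_q[E] = theta * (mean_g - mean_q) over |theta| <= 1,
    i.e. raises energy on sampler outputs and lowers it on posterior latents.
    The maximizer is theta = sign(mean_g - mean_q).
    """
    gap = (alpha + mean(eps)) - mean(posterior_z)
    return (gap > 0) - (gap < 0)  # sign(gap) as an int in {-1, 0, 1}


def sampler_step(alpha, theta, lr):
    """One gradient step lowering E_g[theta * z]; its derivative w.r.t. alpha is theta."""
    return alpha - lr * theta


alpha, lr = 0.0, 0.05
for _ in range(200):
    theta = energy_step(alpha)              # energy ("critic") update
    alpha = sampler_step(alpha, theta, lr)  # sampler ("generator") update

gap = (alpha + mean(eps)) - mean(posterior_z)
print(f"alpha = {alpha:.3f}, remaining mean gap = {gap:+.3f}")

# Generation: one forward pass through the trained sampler, no MCMC.
new_latents = [alpha + random.gauss(0.0, 0.5) for _ in range(5)]
print(new_latents)
```

Because a linear energy only sees the mean, this toy can only align first moments; matching the full shape of the aggregate posterior would require a more expressive (for example, neural) energy. After training, generation is a single forward pass through the sampler, mirroring the point that the sampler doubles as an implicit prior.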
