
Lagging Inference Networks and Posterior Collapse in Variational Autoencoders

Junxian He, Daniel Spokoyny, Graham Neubig, Taylor Berg-Kirkpatrick

TL;DR

This paper identifies posterior collapse in VAEs with powerful autoregressive decoders as a consequence of an inference network that lags behind the model posterior during early training. It introduces a simple training modification that requires no architectural changes and leaves the ELBO objective intact. The inference network is optimized aggressively before each generator update, and this aggressive phase ends once the mutual information between latent variable and observation stops improving. Empirically, the approach mitigates collapse, achieves competitive or superior held-out likelihood on text and image benchmarks, and trains faster than prior collapse-avoidance methods. The work highlights training dynamics as a critical factor in VAE performance and offers a practical, efficient remedy with broad applicability to density estimation tasks.
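
The training schedule described above can be illustrated on a toy model where every quantity is analytic. This is a minimal sketch under stated assumptions, not the authors' implementation. The hypothetical model uses prior N(0, 1), a linear decoder N(theta * z, 1), and an encoder N(a * x, exp(2c)). Its true posterior mean is theta * x / (1 + theta^2), so the lag of the approximate posterior (the distance from the diagonal in Figure 1) can be measured directly. The learning rate, the inner-loop rule (stop when the ELBO improves by less than `tol`, or after `max_inner` steps), the moment-matched mutual-information formula, the patience of 5 checks, and the use of full-batch gradients in place of minibatches are all illustrative assumptions.

```python
import math
import random

# Hypothetical toy VAE (illustration only):
#   p(z) = N(0, 1),  p(x|z) = N(theta * z, 1),  q(z|x) = N(a * x, exp(2c))
# m2 is the batch mean of x**2, the only data statistic the ELBO needs here.


def elbo(theta, a, c, m2):
    """Batch-averaged ELBO: expected log-likelihood minus KL(q(z|x) || p(z))."""
    s2 = math.exp(2 * c)
    recon = -0.5 * math.log(2 * math.pi) - 0.5 * ((1 - theta * a) ** 2 * m2 + theta ** 2 * s2)
    kl = 0.5 * (a ** 2 * m2 + s2 - 1 - 2 * c)
    return recon - kl


def grads(theta, a, c, m2):
    """Analytic gradients of elbo() with respect to (theta, a, c)."""
    s2 = math.exp(2 * c)
    g_theta = a * m2 - theta * (a ** 2 * m2 + s2)
    g_a = m2 * (theta - a * (1 + theta ** 2))
    g_c = 1 - s2 * (1 + theta ** 2)
    return g_theta, g_a, g_c


def mutual_info(a, c, m2):
    """I_q(x; z), assuming q(z) is approximated by the Gaussian N(0, a^2 m2 + s^2)."""
    s2 = math.exp(2 * c)
    return 0.5 * math.log((a ** 2 * m2 + s2) / s2)


def train(xs, lr=0.05, outer_steps=2000, max_inner=100, tol=1e-10, patience=5):
    m2 = sum(x * x for x in xs) / len(xs)
    theta, a, c = 0.1, 0.0, 0.0  # decoder near zero, encoder equal to the prior
    aggressive, best_mi, stale, switch_step = True, -math.inf, 0, None
    for step in range(outer_steps):
        if aggressive:
            # Inner loop: ascend the ELBO with respect to the encoder (a, c) only.
            prev = elbo(theta, a, c, m2)
            for _ in range(max_inner):
                _, g_a, g_c = grads(theta, a, c, m2)
                a, c = a + lr * g_a, c + lr * g_c
                cur = elbo(theta, a, c, m2)
                if cur - prev < tol:
                    break
                prev = cur
            # A single generator (decoder) step with the encoder held fixed.
            g_theta, _, _ = grads(theta, a, c, m2)
            theta += lr * g_theta
            # Leave aggressive mode once I_q has not improved for `patience` checks.
            mi = mutual_info(a, c, m2)
            if mi > best_mi + 1e-4:
                best_mi, stale = mi, 0
            else:
                stale += 1
                if stale >= patience:
                    aggressive, switch_step = False, step
        else:
            # Basic VAE training: one joint gradient step on all parameters.
            g_theta, g_a, g_c = grads(theta, a, c, m2)
            theta, a, c = theta + lr * g_theta, a + lr * g_a, c + lr * g_c
    return {"theta": theta, "a": a, "c": c, "m2": m2, "switch_step": switch_step}
```

For data drawn from N(0, 5), for example `train([random.Random(0).gauss(0, 5 ** 0.5) for _ in range(2000)])`, the returned encoder coefficient `a` should end up close to `theta / (1 + theta**2)`. In other words, the approximate posterior mean sits on the diagonal of the posterior mean space. This toy has no autoregressive decoder, so it shows the mechanics of the schedule rather than reproducing collapse itself.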

Abstract

The variational autoencoder (VAE) is a popular combination of deep latent variable model and accompanying variational learning technique. By using a neural inference network to approximate the model's posterior on latent variables, VAEs efficiently parameterize a lower bound on marginal data likelihood that can be optimized directly via gradient methods. In practice, however, VAE training often results in a degenerate local optimum known as "posterior collapse" where the model learns to ignore the latent variable and the approximate posterior mimics the prior. In this paper, we investigate posterior collapse from the perspective of training dynamics. We find that during the initial stages of training the inference network fails to approximate the model's true posterior, which is a moving target. As a result, the model is encouraged to ignore the latent encoding and posterior collapse occurs. Based on this observation, we propose an extremely simple modification to VAE training to reduce inference lag: depending on the model's current mutual information between latent variable and observation, we aggressively optimize the inference network before performing each model update. Despite introducing neither new model components nor significant complexity over basic VAE, our approach is able to avoid the problem of collapse that has plagued a large amount of previous work. Empirically, our approach outperforms strong autoregressive baselines on text and image benchmarks in terms of held-out likelihood, and is competitive with more complex techniques for avoiding collapse while being substantially faster.
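
The abstract makes aggressive training depend on the current mutual information between x and z. One common way to estimate I_q from a batch of diagonal-Gaussian encoder outputs is sketched below. It assumes the aggregated posterior q(z) is approximated by the batch mixture (1/N) Σ_j q(z | x_j). The function names are hypothetical, and this is not necessarily the exact estimator used in the paper.

```python
import math
import random


def log_normal_diag(z, mu, logvar):
    """log N(z; mu, diag(exp(logvar))) for equal-length lists of floats."""
    return sum(
        -0.5 * (math.log(2 * math.pi) + lv + (zi - m) ** 2 / math.exp(lv))
        for zi, m, lv in zip(z, mu, logvar)
    )


def logsumexp(vals):
    """Numerically stable log(sum(exp(v) for v in vals))."""
    top = max(vals)
    return top + math.log(sum(math.exp(v - top) for v in vals))


def estimate_mi(mus, logvars, rng, samples_per_x=1):
    """Monte Carlo estimate of I_q(x; z) = E_x E_{q(z|x)}[log q(z|x) - log q(z)].

    mus and logvars hold one encoder output per data point in the batch.
    q(z) is approximated by the batch mixture, so the estimate is at most log N.
    """
    n = len(mus)
    total = 0.0
    for mu, lv in zip(mus, logvars):
        for _ in range(samples_per_x):
            # Reparameterized sample z ~ q(z | x_i).
            z = [rng.gauss(m, math.exp(0.5 * l)) for m, l in zip(mu, lv)]
            log_q_z_given_x = log_normal_diag(z, mu, lv)
            log_q_z = logsumexp(
                [log_normal_diag(z, mu_j, lv_j) for mu_j, lv_j in zip(mus, logvars)]
            ) - math.log(n)
            total += log_q_z_given_x - log_q_z
    return total / (n * samples_per_x)
```

A fully collapsed encoder, where every q(z|x) equals the same distribution, gives an estimate of 0. Two well-separated posteriors give roughly log 2. Because of the log N ceiling, the estimate is only meaningful relative to the batch size.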

Paper Structure

This paper contains 30 sections, 6 equations, 7 figures, 10 tables, 1 algorithm.

Figures (7)

  • Figure 1: Left: Depiction of generative model $p({\mathbf{z}})p_{{\bm{\theta}}}({\mathbf{x}}|{\mathbf{z}})$ and inference network $q_{{\bm{\phi}}}({\mathbf{z}}|{\mathbf{x}})$ in VAEs. Right: A toy posterior mean space $(\mu_{{\mathbf{x}}, {\bm{\theta}}}, \mu_{{\mathbf{x}}, {\bm{\phi}}})$ with scalar $z$. The horizontal axis represents the mean of the model posterior $p_{{\bm{\theta}}}({\mathbf{z}}|{\mathbf{x}})$, and the vertical axis represents the mean of the approximate posterior $q_{{\bm{\phi}}}({\mathbf{z}}|{\mathbf{x}})$. The dashed diagonal line represents when the approximate posterior matches the true model posterior in terms of mean.
  • Figure 2: Projections of 500 data samples from a synthetic dataset onto the posterior mean space over the course of training. "iter" denotes the number of generator updates. The top row shows basic VAE training and the bottom row shows our aggressive inference network training. While the approximate posterior lags far behind the true model posterior in basic VAE training, our aggressive training approach successfully moves the points onto the diagonal line and away from inference collapse.
  • Figure 3: Trajectory of one data instance on the posterior mean space with our aggressive training procedure. Horizontal arrow denotes one step of generator update, and vertical arrow denotes the inner loop of inference network update. We note that the approximate posterior $q_{{\bm{\phi}}}({\mathbf{z}}|{\mathbf{x}})$ takes an aggressive step to catch up to the model posterior $p_{{\bm{\theta}}}({\mathbf{z}}|{\mathbf{x}})$.
  • Figure 4: NLL versus AU (active units) for all models on three datasets. For each model we display 5 points, one for each of 5 runs with different random seeds. "Autoregressive" denotes LSTM-LM for text data and PixelCNN for image data; these baselines are plotted at AU = 0 because they have no latent variable. To better visualize the differences between systems on OMNIGLOT, some non-competitive $\beta$-VAE baselines are omitted from the OMNIGLOT plot.
  • Figure 5: Training behavior on Yelp. Left: VAE + annealing. Middle: Our method. Right: $\beta$-VAE ($\beta=0.2$).
  • ...and 2 more figures
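
Figure 4 plots likelihood against the number of active units (AU). A common definition counts a latent dimension u as active when the variance, across data points x, of its posterior mean E_{q(z|x)}[z_u] exceeds a small threshold. The sketch below assumes that definition, a threshold of 0.01, and population variance. The function name is hypothetical.

```python
def active_units(posterior_means, threshold=0.01):
    """Count latent dimensions whose posterior mean varies across data points.

    posterior_means: one list per data point x, holding E_{q(z|x)}[z] per dimension.
    """
    n = len(posterior_means)
    dims = len(posterior_means[0])
    count = 0
    for d in range(dims):
        column = [mu[d] for mu in posterior_means]
        mean = sum(column) / n
        variance = sum((v - mean) ** 2 for v in column) / n  # population variance
        if variance > threshold:
            count += 1
    return count
```

Under this definition a collapsed encoder, whose posterior mean is the same for every input, has AU = 0.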