Sequential Monte Carlo for Inclusive KL Minimization in Amortized Variational Inference

Declan McNamara, Jackson Loper, Jeffrey Regier

TL;DR

The paper tackles amortized variational inference by minimizing the forward KL divergence. Its minimizer is mass-covering, but traditional wake-sleep methods optimize it with biased gradients. The paper introduces SMC-Wake, a framework that uses likelihood-tempered sequential Monte Carlo (LT-SMC) to estimate forward KL gradients. It presents three gradient estimators with asymptotic guarantees, including extensions to particle MCMC. SMC-Wake decouples proposal generation from the encoder and averages normalizing constant estimates. This lets it avoid the circular pathology that plagues Reweighted Wake-Sleep (RWS) and produce more accurate posterior approximations on both synthetic and real data. Empirical results on two moons, MNIST, Gaussian hierarchies, and galaxy spectra show substantial improvements in posterior fidelity and stability over wake-phase training. These results highlight the practical value of SMC-Wake for robust uncertainty quantification in amortized VI.

Abstract

For training an encoder network to perform amortized variational inference, the Kullback-Leibler (KL) divergence from the exact posterior to its approximation, known as the inclusive or forward KL, is an increasingly popular choice of variational objective due to the mass-covering property of its minimizer. However, minimizing this objective is challenging. A popular existing approach, Reweighted Wake-Sleep (RWS), suffers from heavily biased gradients and a circular pathology that results in highly concentrated variational distributions. As an alternative, we propose SMC-Wake, a procedure for fitting an amortized variational approximation that uses likelihood-tempered sequential Monte Carlo samplers to estimate the gradient of the inclusive KL divergence. We propose three gradient estimators, all of which are asymptotically unbiased in the number of iterations and two of which are strongly consistent. Our method interleaves stochastic gradient updates, SMC samplers, and iterative improvement to an estimate of the normalizing constant to reduce bias from self-normalization. In experiments with both simulated and real datasets, SMC-Wake fits variational distributions that approximate the posterior more accurately than existing methods.
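The sketch below makes the abstract's description concrete. It shows one likelihood-tempered SMC run and the particle estimate of the forward KL gradient that the run feeds. It is a sketch under stated assumptions, not the authors' implementation. The following are all illustrative assumptions:

- a conjugate Gaussian toy model;
- a single observation $x$ with one Gaussian $q$ in place of an amortized encoder;
- 20 equally spaced temperatures with multinomial resampling at every step;
- random-walk Metropolis moves;
- the hypothetical helper names `lt_smc` and `forward_kl_grad`.

In this sketch the sampler starts from the prior and never queries $q$, which mirrors the decoupling of proposal generation from the encoder mentioned in the TL;DR.

```python
import numpy as np

# Hypothetical toy model (not from the paper): z ~ N(0, 1), x | z ~ N(z, sigma^2).
# lt_smc and forward_kl_grad are hypothetical names used only in this sketch.
def log_prior(z):
    return -0.5 * z**2 - 0.5 * np.log(2 * np.pi)

def log_lik(z, x, sigma):
    return -0.5 * ((x - z) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)

def lt_smc(x, sigma, n_particles=1000, n_temps=20, n_mh=5, step=0.5, rng=None):
    """Likelihood-tempered SMC targeting p(z) p(x | z)^beta as beta goes from 0 to 1.
    Returns equally weighted particles (after the final resampling step)
    and an estimate of log p(x)."""
    rng = np.random.default_rng() if rng is None else rng
    betas = np.linspace(0.0, 1.0, n_temps + 1)
    z = rng.standard_normal(n_particles)  # exact draws from the prior (beta = 0)
    log_Z = 0.0
    for b_prev, b in zip(betas[:-1], betas[1:]):
        # Incremental weight: likelihood raised to the temperature increment.
        log_inc = (b - b_prev) * log_lik(z, x, sigma)
        m = log_inc.max()
        w = np.exp(log_inc - m)
        log_Z += m + np.log(w.mean())  # running log normalizing-constant estimate
        # Multinomial resampling in proportion to the incremental weights.
        z = z[rng.choice(n_particles, size=n_particles, p=w / w.sum())]
        # Random-walk Metropolis moves that leave p(z) p(x | z)^b invariant.
        for _ in range(n_mh):
            prop = z + step * rng.standard_normal(n_particles)
            log_acc = (log_prior(prop) + b * log_lik(prop, x, sigma)
                       - log_prior(z) - b * log_lik(z, x, sigma))
            accept = np.log(rng.uniform(size=n_particles)) < log_acc
            z = np.where(accept, prop, z)
    return z, log_Z

def forward_kl_grad(z, mu, log_s):
    """Particle estimate of the gradient of KL(p(z|x) || q) with respect to
    (mu, log_s) for q = N(mu, s^2), i.e. -E_p[grad log q(z)]."""
    s2 = np.exp(2 * log_s)
    return -np.array([np.mean((z - mu) / s2),
                      np.mean((z - mu) ** 2 / s2 - 1.0)])

# Usage: one sampler run, then gradient steps on the variational parameters.
x, sigma = 2.0, 0.5
z, log_Z = lt_smc(x, sigma, rng=np.random.default_rng(0))
params = np.zeros(2)  # (mu, log_s) of one Gaussian standing in for q_phi(z | x)
for _ in range(2000):
    params -= 0.05 * forward_kl_grad(z, params[0], params[1])
mu, s = params[0], np.exp(params[1])
```

For this toy model the exact posterior is $\mathcal{N}(1.6, 0.2)$ and $\log p(x) \approx -2.63$. The fitted `mu` and `s` should therefore land near 1.6 and 0.447, and `log_Z` should land near $-2.63$, up to Monte Carlo error.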

Paper Structure

This paper contains 27 sections, 6 theorems, 29 equations, 14 figures, 3 tables, and 5 algorithms.

Key Result

Proposition 1

Let $\mathcal{L}(q)$ denote the surrogate objective defined above for fixed $x$ and fixed $K \in \mathbb{N}$. Let $p$ denote the posterior $p(z \mid x)$. Then there exists $q(z) \neq p(z \mid x)$ such that $\mathcal{L}(q) < \mathcal{L}(p)$.
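The surrogate objective in Proposition 1 is defined earlier in the paper and is not reproduced in this excerpt. The sketch below assumes it takes the familiar wake-phase form

$\mathcal{L}(q) = \mathbb{E}_{z_{1:K} \sim q}\big[-\sum_{k=1}^K \bar w_k \log q(z_k)\big]$, with self-normalized weights $\bar w_k \propto p(z_k, x)/q(z_k)$.

Under that assumption, it checks the proposition numerically on a hypothetical conjugate Gaussian model.

- **When $q = p(z \mid x)$:** every weight equals $1/K$, so $\mathcal{L}(p)$ is the posterior's entropy.
- **For a narrow Gaussian $q$ with scale $s$, centered at the posterior mean:** the weighted average of $-\log q(z_k)$ exceeds $\log s + \tfrac12 \log 2\pi$ by at most $\max_k \epsilon_k^2/2$, where $\epsilon_k$ are the standardized draws. That bound's expectation does not depend on $s$, so $\mathcal{L}(q) \to -\infty$ as $s \to 0$.

```python
import numpy as np

# Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, sigma^2), one observation x.
x, sigma = 2.0, 0.5
post_var = 1.0 / (1.0 + 1.0 / sigma**2)   # exact posterior variance (0.2)
post_mean = post_var * x / sigma**2        # exact posterior mean (1.6)

def log_normal(z, mean, std):
    return -0.5 * ((z - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

def log_joint(z):
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, sigma)

def wake_surrogate(mean, std, K=10, n_reps=20000, seed=0):
    """Monte Carlo estimate of the ASSUMED surrogate
    L(q) = E_{z_1..z_K ~ q}[-sum_k wbar_k log q(z_k)] for q = N(mean, std^2)."""
    rng = np.random.default_rng(seed)
    z = mean + std * rng.standard_normal((n_reps, K))
    log_q = log_normal(z, mean, std)
    log_w = log_joint(z) - log_q
    # Self-normalize the importance weights within each batch of K draws.
    wbar = np.exp(log_w - log_w.max(axis=1, keepdims=True))
    wbar /= wbar.sum(axis=1, keepdims=True)
    return np.mean(np.sum(wbar * -log_q, axis=1))

L_post = wake_surrogate(post_mean, np.sqrt(post_var))           # q = exact posterior
L_narrow = wake_surrogate(post_mean, 0.01 * np.sqrt(post_var))  # q far too concentrated
print(L_post, L_narrow)
```

The first value should be close to the posterior entropy, about 0.614, and the second should be well below it. Under the assumed surrogate, this is the concentration that the abstract calls a circular pathology.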

Figures (14)

  • Figure 1: Posterior approximations given $x_{16}$ by SMC-Wake, Wake, and Defensive Wake. The bottom right panel depicts the exact posterior distribution.
  • Figure 2: Illustration of mass concentration in wake-phase training. Wake-phase reconstructions (middle) of real MNIST digits (top) with label zero all look nearly identical, unlike SMC-Wake reconstructions (bottom).
  • Figure 3: Average forward KL divergence during training for an encoder $q_\phi$ fit using one $K=10{,}000$ sampler per point (blue) and $M=100$ different $K=100$ samplers (green).
  • Figure 4: Posterior estimation for one example spectrum $x$ from the training set. We used MCMC (green) and SMC-Wake (blue) to estimate the posterior on $\theta$. Smoothed (marginal) estimates of the posterior for 4 of the 11 parameters $\theta_i$ are plotted, with the true value of the parameter indicated with a red vertical line.
  • Figure 5: Comparison of two-moons variational posteriors.
  • ...and 9 more figures
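Figure 3 compares a single sampler with $K=10{,}000$ against $M=100$ samplers with $K=100$ each. The abstract also mentions iteratively improving a normalizing-constant estimate to reduce self-normalization bias. One way to combine several independent runs for the same $x$ is sketched below:

- weight run $m$ by $\hat Z_m / \sum_j \hat Z_j$;
- weight each of its particles by its within-run self-normalized weight.

Both importance sampling and SMC satisfy $\mathbb{E}[\hat Z_m \sum_k \bar w_{mk} f(z_{mk})] = p(x)\,\mathbb{E}_{p(z \mid x)}[f(z)]$. The pooled average is therefore a ratio of two averages over runs, and it converges as $M$ grows even with $K$ fixed.

The following are assumptions for illustration, not the paper's exact estimator:

- the pooling rule itself;
- the hypothetical `PooledPosterior` and `run_sampler` helpers;
- the toy model;
- prior importance sampling as a stand-in for LT-SMC.

```python
import numpy as np

# Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, sigma^2), one observation x.
x, sigma = 2.0, 0.5

def log_lik(z):
    return -0.5 * ((x - z) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)

def run_sampler(K, rng):
    """Stand-in for one LT-SMC run: importance sampling from the prior.
    Returns particles, self-normalized weights, and log Z-hat (Z-hat is unbiased)."""
    z = rng.standard_normal(K)
    log_w = log_lik(z)
    m = log_w.max()
    w = np.exp(log_w - m)
    return z, w / w.sum(), m + np.log(w.mean())

class PooledPosterior:
    """Hypothetical helper that pools every run made so far for one observation."""
    def __init__(self):
        self.z, self.wbar, self.log_Z = [], [], []

    def add(self, z, wbar, log_Z):
        self.z.append(z)
        self.wbar.append(wbar)
        self.log_Z.append(log_Z)

    def weights(self):
        # Run weights Z-hat_m / sum_j Z-hat_j, computed stably in log space.
        log_Z = np.array(self.log_Z)
        run_w = np.exp(log_Z - log_Z.max())
        run_w /= run_w.sum()
        w = np.concatenate([r * wb for r, wb in zip(run_w, self.wbar)])
        return np.concatenate(self.z), w

    def log_evidence(self):
        # Log of the average of the Z-hat_m across runs.
        log_Z = np.array(self.log_Z)
        return log_Z.max() + np.log(np.mean(np.exp(log_Z - log_Z.max())))

def forward_kl_grad(z, w, mu, log_s):
    """Weighted estimate of the gradient of KL(p(z|x) || q) w.r.t. (mu, log_s), q = N(mu, s^2)."""
    s2 = np.exp(2 * log_s)
    return -np.array([np.sum(w * (z - mu)) / s2,
                      np.sum(w * ((z - mu) ** 2 / s2 - 1.0))])

rng = np.random.default_rng(0)
pool = PooledPosterior()
for _ in range(100):  # M = 100 runs of K = 100 particles each
    pool.add(*run_sampler(100, rng))

z, w = pool.weights()
pooled_mean = np.sum(w * z)
pooled_var = np.sum(w * (z - pooled_mean) ** 2)
grad = forward_kl_grad(z, w, mu=0.0, log_s=0.0)
```

Expected results:

- The pooled mean and variance should be close to the exact posterior values 1.6 and 0.2.
- `pool.log_evidence()` should be close to $\log p(x) \approx -2.63$.

Swapping `run_sampler` for an LT-SMC routine leaves `PooledPosterior` unchanged. In a training loop, runs made at successive iterations would keep accumulating in the same pool. This works because the sampler never depends on the encoder being trained.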

Theorems & Definitions (8)

  • Proposition 1
  • Proposition 2
  • Proposition 3
  • Proposition 1
  • proof
  • Proposition
  • Proposition 3
  • proof