Batch and match: black-box variational inference with a score-based divergence

Diana Cai, Chirag Modi, Loucas Pillaud-Vivien, Charles C. Margossian, Robert M. Gower, David M. Blei, Lawrence K. Saul

TL;DR

Batch and Match (BaM) introduces a score-based divergence $\mathscr{D}(q;p)$ for black-box variational inference, enabling efficient, closed-form proximal updates for Gaussian full-covariance variational families. The method alternates between batch-based score estimation and a closed-form match step, yielding updates for the mean $\mu_t$ and covariance $\Sigma_t$ that converge exponentially to the target under Gaussian assumptions in the infinite-batch limit. Empirical results on Gaussian and non-Gaussian targets, including hierarchical Bayesian models and a deep generative model, show that BaM typically requires fewer gradient evaluations than ELBO-based BBVI and GSM, with faster convergence and greater stability at larger batch sizes. The score-based divergence also provides a natural goodness-of-fit diagnostic and remains applicable to unnormalized targets, highlighting BaM as a robust alternative to traditional VI approaches in high-dimensional, expressive settings.

Abstract

Most leading implementations of black-box variational inference (BBVI) are based on optimizing a stochastic evidence lower bound (ELBO). But such approaches to BBVI often converge slowly due to the high variance of their gradient estimates and their sensitivity to hyperparameters. In this work, we propose batch and match (BaM), an alternative approach to BBVI based on a score-based divergence. Notably, this score-based divergence can be optimized by a closed-form proximal update for Gaussian variational families with full covariance matrices. We analyze the convergence of BaM when the target distribution is Gaussian, and we prove that in the limit of infinite batch size the variational parameter updates converge exponentially quickly to the target mean and covariance. We also evaluate the performance of BaM on Gaussian and non-Gaussian target distributions that arise from posterior inference in hierarchical and deep generative models. In these experiments, we find that BaM typically converges in fewer (and sometimes significantly fewer) gradient evaluations than leading implementations of BBVI based on ELBO maximization.
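
The TL;DR notes that the score-based divergence works as a goodness-of-fit diagnostic and applies to unnormalized targets. The sketch below illustrates why in one dimension. It assumes the divergence has the covariance-weighted form $\mathbb{E}_q\big[\sigma^2\,(\partial_z \log q(z) - \partial_z \log p(z))^2\big]$ for $q = \mathcal{N}(\mu,\sigma^2)$; this summary does not spell out the formula, so treat that form as an assumption. The names `score_divergence_1d` and `gaussian_score` are hypothetical helpers, not part of any BaM implementation.

```python
import math
import random


def score_divergence_1d(mu, sigma2, target_score, num_samples=10_000, seed=0):
    """Monte Carlo estimate of E_q[sigma2 * (score_q(z) - score_p(z))^2].

    q is N(mu, sigma2). The weighting by sigma2 is an assumed form of the
    score-based divergence, specialized to one dimension.
    """
    rng = random.Random(seed)
    sigma = math.sqrt(sigma2)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)            # batch step: draw z ~ q
        score_q = -(z - mu) / sigma2        # d/dz log q(z)
        diff = score_q - target_score(z)    # mismatch between the two scores
        total += sigma2 * diff * diff
    return total / num_samples


def gaussian_score(z, mean=1.0, var=4.0):
    """Score of a Gaussian target; the normalizing constant never appears."""
    return -(z - mean) / var


# q equal to the target: the scores agree at every sample.
matched = score_divergence_1d(1.0, 4.0, gaussian_score)
# q = N(0, 1) against the N(1, 4) target: the estimate is strictly positive.
mismatched = score_divergence_1d(0.0, 1.0, gaussian_score)
print(matched, mismatched)
```

Only the target's score enters the estimate, so an unnormalized log-density is enough. For the mismatched pair the assumed divergence has the analytic value 0.625, so the printed estimate should land near that number up to Monte Carlo error.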

Paper Structure (38 sections, 26 theorems, 153 equations, 14 figures, 1 algorithm)

Key Result

Theorem 3.1

Suppose that $p = \mathcal{N}(\mu_*,\Sigma_*)$ in Algorithm 1, and let $\alpha\!>\!0$ denote the minimum eigenvalue of the matrix $\Sigma_*^{-\frac{1}{2}}\Sigma_0\Sigma_*^{-\frac{1}{2}}$. For any fixed level of regularization $\lambda\!>\!0$, define $\beta$ and $\delta$ ..., where $\beta\in(0,1]$ measures the quality of initialization and $\delta\in(0,1)$ denotes a rate of decay. Then with probability 1 in the limit of infinite batch size ...

Figures (14)

  • Figure 5.1: Gaussian targets of increasing dimension. Solid curves indicate the mean over 10 runs (transparent curves). ADVI, Score, Fisher, and GSM use a batch size of $B\!=\!2$. The batch size for BaM is given in the legend.
  • Figure 5.2: Non-Gaussian targets constructed using the sinh-arcsinh distribution, varying the skew $s$ and the tail weight $t$. The curves denote the mean of the forward KL divergence over 10 runs, and shaded regions denote their standard error. ADVI, Score, Fisher, and GSM use a batch size of $B\!=\! 5$.
  • Figure 5.3: Posterior inference in Bayesian models. The curves denote the mean over 5 runs, and shaded regions denote their standard error. Solid curves ($B\!=\!32$) correspond to larger batch sizes than dashed curves ($B\!=\!8$).
  • Figure 5.4: Image reconstruction and error when the posterior mean of $z'$ is fed into the generative neural network. The beige and purple stars highlight the best outcome for ADVI and BaM, respectively, after 3,000 gradient evaluations.
  • Figure D.1: Plot of the function $f$ in eq. (\ref{eq:f_nu}), as well as its fixed point and upper and lower bounds from Lemma \ref{lemma-f-mono}, with $\lambda\!=\!4$ and $\varepsilon^2\!=\!1$.
  • ...and 9 more figures

Theorems & Definitions (53)

  • Theorem 3.1: Exponential convergence
  • proof: Proof Sketch
  • Lemma 1
  • proof
  • Definition 1: Score-based divergence
  • Theorem A.1: Nonnegativity
  • proof
  • Theorem A.2: Affine invariance
  • proof
  • Theorem A.3: Annealing
  • ...and 43 more