
Understanding Stochastic Natural Gradient Variational Inference

Kaiwen Wu, Jacob R. Gardner

TL;DR

The paper addresses the lack of non-asymptotic convergence guarantees for stochastic NGVI. For conjugate likelihoods, it proves an $\mathcal{O}\left(\frac{1}{T}\right)$ rate whose complexity is no worse than that of stochastic gradient descent. For non-conjugate likelihoods, it shows that stochastic NGVI with the canonical parameterization implicitly optimizes a non-convex objective. The analysis combines the equivalence between natural gradient descent (NGD) and mirror descent (MD), the notions of relative smoothness and relative convexity, and a data-subsampling model that bounds the variance of the stochastic gradients. The results suggest that NGVI's practical advantage over vanilla SGD comes from likely better constant factors rather than a faster rate, and they explain why global guarantees are harder to obtain for non-conjugate likelihoods. Experiments on Bayesian linear regression and on non-conjugate Bayesian logistic regression support these findings. Overall, the work clarifies the theoretical picture of NGVI, gives rigorous convergence guarantees in the conjugate setting, and points to stochastic mirror-descent analyses as a route for future work on variational methods.

Abstract

Stochastic natural gradient variational inference (NGVI) is a popular posterior inference method with applications in various probabilistic models. Despite its wide usage, little is known about the non-asymptotic convergence rate in the stochastic setting. We aim to lessen this gap and provide a better understanding. For conjugate likelihoods, we prove the first $\mathcal{O}(\frac{1}{T})$ non-asymptotic convergence rate of stochastic NGVI. The complexity is no worse than that of stochastic gradient descent (a.k.a. black-box variational inference), and the rate likely has a better constant dependency, which leads to faster convergence in practice. For non-conjugate likelihoods, we show that stochastic NGVI with the canonical parameterization implicitly optimizes a non-convex objective. Thus, a global convergence rate of $\mathcal{O}(\frac{1}{T})$ is unlikely without some significant new understanding of optimizing the ELBO using natural gradients.
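To make the conjugate setting concrete, here is a minimal sketch of stochastic NGVI for Bayesian linear regression with a Gaussian prior and a known noise precision. It assumes the standard conjugate update, in which each step moves the natural parameters of $q$ toward a minibatch estimate rescaled by $N/B$, and it assumes the step size $\rho_t = 1/(t+1)$. The data, sizes, and function names are hypothetical and are not taken from the paper. The sketch tracks the KL divergence from the current $q_t$ to the exact posterior $q^*$, similar in spirit to the left panel of Figure 1.

```python
import numpy as np

def gaussian_kl(mu_q, prec_q, mu_p, prec_p):
    """KL(N(mu_q, prec_q^{-1}) || N(mu_p, prec_p^{-1})) using precision matrices."""
    d = mu_q.size
    cov_q = np.linalg.inv(prec_q)
    diff = mu_p - mu_q
    _, logdet_q = np.linalg.slogdet(prec_q)
    _, logdet_p = np.linalg.slogdet(prec_p)
    return 0.5 * (np.trace(prec_p @ cov_q) - d + diff @ prec_p @ diff
                  + logdet_q - logdet_p)

# Hypothetical synthetic data set: y = X w + noise.
rng = np.random.default_rng(0)
N, d, B = 100, 2, 10          # data size, dimension, minibatch size (assumed)
alpha, beta = 1.0, 1.0        # prior precision and noise precision (assumed known)
X = rng.standard_normal((N, d))
y = X @ np.array([1.0, -2.0]) + rng.standard_normal(N) / np.sqrt(beta)

# Exact posterior in natural form: precision Lambda* and h* = Lambda* mu*.
prec_star = alpha * np.eye(d) + beta * X.T @ X
h_star = beta * X.T @ y
mu_star = np.linalg.solve(prec_star, h_star)

def stochastic_ngvi(T, rng):
    """Run T stochastic natural-gradient steps on (h, Lambda); return KL(q_t || q*)."""
    h, prec = np.zeros(d), alpha * np.eye(d)      # q_0 = prior
    kls = []
    for t in range(T):
        idx = rng.choice(N, size=B, replace=False)
        scale = N / B                              # makes the minibatch sum unbiased
        h_hat = scale * beta * X[idx].T @ y[idx]
        prec_hat = alpha * np.eye(d) + scale * beta * X[idx].T @ X[idx]
        rho = 1.0 / (t + 1)                        # assumed step-size schedule
        # Natural-gradient step for a conjugate model: convex combination.
        h = (1 - rho) * h + rho * h_hat
        prec = (1 - rho) * prec + rho * prec_hat
        kls.append(gaussian_kl(np.linalg.solve(prec, h), prec, mu_star, prec_star))
    return kls

kls = stochastic_ngvi(5000, np.random.default_rng(1))
print(f"KL after 10 steps: {kls[9]:.3e}, after 5000 steps: {kls[-1]:.3e}")
```

With $\rho_t = 1/(t+1)$ the iterate is exactly the running average of the minibatch estimates, so its expected squared distance to the exact natural parameters shrinks like $1/T$. This is a property of this toy schedule and is not a restatement of the paper's bound.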

Paper Structure

This paper contains 28 sections, 19 theorems, 118 equations, and 2 figures.

Key Result

Lemma 1

Suppose the NGD update (Definition 2) and the MD update in proximal form (Definition 4) start from the same variational distribution $q_0$, i.e., $\boldsymbol{\eta}_0 = \nabla A^*(\boldsymbol{\omega}_0)$. Then $\boldsymbol{\eta}_t = \nabla A^*(\boldsymbol{\omega}_t)$ for all $t \geq 0$; that is, both updates produce the same sequence of variational distributions.
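Lemma 1 says that NGD on the natural parameters and MD on the mean parameters, started from the same $q_0$, trace the same distributions. Below is a minimal numerical check for a one-dimensional Gaussian family, assuming the objective is $\mathrm{KL}(q \,\|\, p)$ for a hypothetical Gaussian target $p$ (which equals the negative ELBO of a conjugate Gaussian model up to a constant). NGD preconditions the $\boldsymbol{\eta}$-gradient with the Fisher matrix $\nabla^2 A(\boldsymbol{\eta})$; MD uses $A^*$ as the mirror map. Gradients are taken by finite differences so the two updates are computed independently; all names are illustrative.

```python
import numpy as np

# Hypothetical target p = N(M, S2); objective L(q) = KL(q || p).
M, S2 = 1.5, 0.5

def mean_from_natural(eta):
    """omega = grad A(eta): mean parameters (E[x], E[x^2]) of N(mu, sigma2)."""
    sigma2 = -1.0 / (2.0 * eta[1])
    mu = eta[0] * sigma2
    return np.array([mu, mu**2 + sigma2])

def natural_from_mean(omega):
    """eta = grad A*(omega) = (mu / sigma2, -1 / (2 sigma2))."""
    sigma2 = omega[1] - omega[0] ** 2
    return np.array([omega[0] / sigma2, -1.0 / (2.0 * sigma2)])

def fisher(eta):
    """F(eta) = Hessian of A = covariance of the sufficient statistics (x, x^2)."""
    mu, second = mean_from_natural(eta)
    s2 = second - mu**2
    return np.array([[s2, 2 * mu * s2],
                     [2 * mu * s2, 2 * s2**2 + 4 * mu**2 * s2]])

def objective(omega):
    """KL(q || p) as a function of the mean parameters, constants dropped."""
    sigma2 = omega[1] - omega[0] ** 2
    return -0.5 * np.log(sigma2) + (omega[1] - 2 * M * omega[0]) / (2 * S2)

def num_grad(f, x, h=1e-5):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2 * h)
    return g

def ngd_step(eta, rho):
    """NGD on natural parameters: eta - rho * F(eta)^{-1} grad_eta L."""
    grad_eta = num_grad(lambda e: objective(mean_from_natural(e)), eta)
    return eta - rho * np.linalg.solve(fisher(eta), grad_eta)

def md_step(omega, rho):
    """MD on mean parameters: grad A*(omega') = grad A*(omega) - rho * grad_omega L."""
    eta_next = natural_from_mean(omega) - rho * num_grad(objective, omega)
    return mean_from_natural(eta_next)

rho = 0.3
eta = natural_from_mean(np.array([0.0, 1.0]))  # q_0 = N(0, 1), natural parameters
omega = mean_from_natural(eta)                 # the same q_0, mean parameters
max_gap = 0.0
for t in range(20):
    eta, omega = ngd_step(eta, rho), md_step(omega, rho)
    max_gap = max(max_gap, np.max(np.abs(eta - natural_from_mean(omega))))
print(f"max |eta_t - grad A*(omega_t)| over 20 steps: {max_gap:.1e}")
```

The gap stays at finite-difference precision, and both iterates approach the target's natural parameters $(M/S_2, -1/(2S_2)) = (3, -1)$.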

Figures (2)

  • Figure 1: Mini-batch Bayesian linear regression on the Bike dataset. Left: The KL divergence to the optimum $q^*$. Right: The training negative log predictive density.
  • Figure 2: Bayesian logistic regression on Mushroom and MNIST. Labels with "(p)" use the stochastic gradient given by Price's theorem; labels with "(r)" use the stochastic gradient given by the reparameterization trick.
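Figure 2 compares two stochastic gradient estimators. As an illustration only, here is a one-dimensional sketch for $q = \mathcal{N}(\mu, \sigma^2)$ and $f(x) = \log \mathrm{sigmoid}(x)$, a hypothetical stand-in for a single data point's term in a logistic-regression likelihood. Price's theorem gives $\partial_{\sigma^2}\mathbb{E}_q[f(x)] = \tfrac{1}{2}\mathbb{E}_q[f''(x)]$, while the reparameterization $x = \mu + \sigma\epsilon$ gives $\partial_{\sigma^2}\mathbb{E}_q[f(x)] = \mathbb{E}[f'(x)\,\epsilon]/(2\sigma)$. Both are unbiased; the code compares their sample means and variances. The paper's estimators act on a full multivariate Gaussian, which this scalar sketch does not cover.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# f(x) = log sigmoid(x): f'(x) = sigmoid(-x), f''(x) = -sigmoid(x) sigmoid(-x).
def df(x):
    return sigmoid(-x)

def d2f(x):
    return -sigmoid(x) * sigmoid(-x)

def price_samples(d2f, mu, sigma, eps):
    """Per-sample Price estimator of d/d(sigma^2) E[f(x)]: 0.5 * f''(mu + sigma*eps)."""
    return 0.5 * d2f(mu + sigma * eps)

def reparam_samples(df, mu, sigma, eps):
    """Per-sample reparameterization estimator: d/dsigma gives f'(x) * eps, and
    the chain rule d/d(sigma^2) = (1 / (2 sigma)) d/dsigma rescales it."""
    return df(mu + sigma * eps) * eps / (2.0 * sigma)

rng = np.random.default_rng(0)
eps = rng.standard_normal(100_000)
mu, sigma = 0.5, 0.8                       # hypothetical variational parameters
p = price_samples(d2f, mu, sigma, eps)
r = reparam_samples(df, mu, sigma, eps)
print(f"Price (p):   mean {p.mean():.4f}, variance {p.var():.2e}")
print(f"Reparam (r): mean {r.mean():.4f}, variance {r.var():.2e}")
```

The same functions can be sanity-checked on $f(x) = x^4$, where $\partial_{\sigma^2}\mathbb{E}[x^4] = 6(\mu^2 + \sigma^2)$ in closed form.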

Theorems & Definitions (28)

  • Remark 1: Overload $\ell$
  • Definition 1: FIM
  • Definition 2: NGD
  • Definition 3: Bregman Divergence
  • Definition 4: MD
  • Lemma 1: NGD $=$ MD
  • Definition 5
  • Remark 2
  • Definition 6 (Hanzely and Richtárik, 2021)
  • Lemma 2
  • ...and 18 more
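The Bregman divergence (Definition 3) links NGD and MD through a standard identity: for an exponential family with log-partition $A$, the Bregman divergence generated by $A$ (or by its conjugate $A^*$) is a KL divergence between members of the family. The sketch below checks this for a Bernoulli family, chosen only for brevity. It assumes the convention $D_\phi(x, y) = \phi(x) - \phi(y) - \langle \nabla\phi(y), x - y\rangle$; the paper's Definition 3 may order the arguments differently, and the function names are illustrative.

```python
import numpy as np

def log_partition(eta):
    """A(eta) = log(1 + exp(eta)) for a Bernoulli in natural parameters."""
    return np.logaddexp(0.0, eta)

def neg_entropy(omega):
    """A*(omega) = omega log omega + (1 - omega) log(1 - omega), the conjugate of A."""
    return omega * np.log(omega) + (1 - omega) * np.log(1 - omega)

def sigmoid(eta):
    return 1.0 / (1.0 + np.exp(-eta))   # grad A: natural -> mean parameters

def logit(omega):
    return np.log(omega / (1 - omega))  # grad A*: mean -> natural parameters

def bregman(phi, grad_phi, x, y):
    """D_phi(x, y) = phi(x) - phi(y) - grad_phi(y) * (x - y), scalar case."""
    return phi(x) - phi(y) - grad_phi(y) * (x - y)

def bernoulli_kl(p, q):
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

p1, p2 = 0.3, 0.8                      # hypothetical distributions q_1, q_2
kl = bernoulli_kl(p1, p2)              # KL(q_1 || q_2)
d_A = bregman(log_partition, sigmoid, logit(p2), logit(p1))  # arguments swapped
d_Astar = bregman(neg_entropy, logit, p1, p2)
print(f"KL = {kl:.6f}, D_A = {d_A:.6f}, D_A* = {d_Astar:.6f}")
```

Note the argument order: in natural parameters $\mathrm{KL}(q_1 \,\|\, q_2) = D_A(\boldsymbol{\eta}_2, \boldsymbol{\eta}_1)$, whereas in mean parameters it is $D_{A^*}(\boldsymbol{\omega}_1, \boldsymbol{\omega}_2)$.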