Nearly Dimension-Independent Convergence of Mean-Field Black-Box Variational Inference

Kyurae Kim, Yi-An Ma, Trevor Campbell, Jacob R. Gardner

TL;DR

This work provides non-asymptotic convergence guarantees for black-box variational inference (BBVI) with a mean-field location-scale variational family and the reparametrization gradient. For a d-dimensional μ-strongly log-concave and L-log-smooth target with condition number κ = L/μ, the authors show that BBVI attains an ε-accurate solution in squared distance in O(log d · κ^{2} · ε^{-1}) iterations when the variational base distribution φ has sub-Gaussian tails; heavier-tailed bases with k finite moments yield O(d^{2/k} · κ^{2} · ε^{-1}). If the Hessian of the target log-density is constant, the iteration complexity becomes independent of d, i.e., O(κ^{2} · ε^{-1}). Central to the analysis is a detailed bound on the variance of the reparametrization gradient, revealing that for mean-field families only one coordinate can exhibit heavy-tailed behavior, while the rest are well-behaved as d grows. The paper also proves a lower bound showing that the gradient-variance bound cannot be improved using only spectral information about the Hessian. Overall, the results quantify how weakly the convergence of mean-field BBVI depends on dimension, which can inform the choice of variational family in high dimensions.

Abstract

We prove that, given a mean-field location-scale variational family, black-box variational inference (BBVI) with the reparametrization gradient converges at a rate that is nearly free of explicit dimension dependence. Specifically, for a $d$-dimensional strongly log-concave and log-smooth target, the number of iterations for BBVI with a sub-Gaussian family to obtain a solution $\epsilon$-close to the global optimum has a dimension dependence of $\mathrm{O}(\log d)$. This is a significant improvement over the $\mathrm{O}(d)$ dependence of full-rank location-scale families. For heavy-tailed families, we prove a weaker $\mathrm{O}(d^{2/k})$ dependence, where $k$ is the number of finite moments of the family. Additionally, if the Hessian of the target log-density is constant, the complexity is free of any explicit dimension dependence. We also prove that our bound on the gradient variance, which is key to our result, cannot be improved using only spectral bounds on the Hessian of the target log-density.
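
To make the setting of the abstract concrete, here is a minimal sketch of mean-field BBVI with the reparametrization gradient. It is not the authors' implementation. Everything in it is an assumption chosen for illustration: a Gaussian target with precision matrix `A` and mean `b` (so the Hessian of the log-density is constant), a standard Gaussian base distribution, plain SGD with a clamp on the scales (`min_scale`) rather than whatever update rule the paper analyzes, and the function names themselves.

```python
import random


def grad_neg_log_target(z, A, b):
    """Gradient of f(z) = 0.5 (z - b)^T A (z - b), i.e. the negative
    log-density of a Gaussian target with precision A (illustrative choice)."""
    d = len(z)
    return [sum(A[i][j] * (z[j] - b[j]) for j in range(d)) for i in range(d)]


def reparam_gradient(m, s, A, b, rng, n_samples=16):
    """Monte Carlo reparametrization gradient of
    F(m, s) = E_u[f(m + s * u)] - sum_i log s_i  (negative ELBO up to a constant)
    for a mean-field family z_i = m_i + s_i * u_i with u_i ~ N(0, 1)."""
    d = len(m)
    grad_m = [0.0] * d
    grad_s = [0.0] * d
    for _ in range(n_samples):
        u = [rng.gauss(0.0, 1.0) for _ in range(d)]
        z = [m[i] + s[i] * u[i] for i in range(d)]
        g = grad_neg_log_target(z, A, b)
        for i in range(d):
            grad_m[i] += g[i] / n_samples         # dz_i/dm_i = 1
            grad_s[i] += g[i] * u[i] / n_samples  # dz_i/ds_i = u_i
    # The entropy of q is sum_i log s_i + const, so -entropy contributes -1/s_i.
    grad_s = [grad_s[i] - 1.0 / s[i] for i in range(d)]
    return grad_m, grad_s


def run_bbvi(A, b, n_iters=3000, step=0.01, seed=0, min_scale=1e-2):
    """Plain SGD on (m, s); scales are clamped below to stay positive."""
    rng = random.Random(seed)
    d = len(b)
    m = [0.0] * d
    s = [1.0] * d
    for _ in range(n_iters):
        gm, gs = reparam_gradient(m, s, A, b, rng)
        m = [m[i] - step * gm[i] for i in range(d)]
        s = [max(s[i] - step * gs[i], min_scale) for i in range(d)]
    return m, s


if __name__ == "__main__":
    # Hypothetical 4-dimensional target: diagonal precision plus a constant offset.
    A = [[(i + 1.0 if i == j else 0.0) + 0.3 for j in range(4)] for i in range(4)]
    b = [1.0, -1.0, 0.5, 2.0]
    m, s = run_bbvi(A, b)
    print("location:", [round(v, 2) for v in m])
    print("scale:   ", [round(v, 2) for v in s])
```

For this Gaussian target the optimal mean-field Gaussian is known in closed form (location `b`, scales `1 / sqrt(A[i][i])`), which gives a simple check that the iterates settle near the right point. The sketch says nothing about the rates in the abstract; it only shows which quantities the gradient-variance analysis is about.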

Paper Structure

This paper contains 7 sections, 22 equations.

Theorems & Definitions (6)

  • Definition: Smoothness
  • Definition: Strong Convexity
  • Definition 2.1: Location-Scale Variational Family
  • Definition 2.3: Full-Rank Location-Scale Family
  • Definition 2.4: Mean-Field Location-Scale Family
  • Definition 2.6: Reparametrization Gradient
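
The listing above names the definitions without stating them. For orientation, the block below gives the forms these notions usually take in the BBVI literature. These are assumed standard statements, not quotations from the paper, and the symbols ($f$, $\mathcal{T}_{\lambda}$, $C$, $m$, $u$, $\sigma_i$) are notation chosen here; the paper's own definitions may differ in detail, for example in how the scale matrix is parametrized.

```latex
\begin{align*}
&\text{($L$-smoothness)} &&
  \|\nabla f(x) - \nabla f(y)\|_2 \le L \, \|x - y\|_2, \\
&\text{($\mu$-strong convexity)} &&
  f(y) \ge f(x) + \langle \nabla f(x),\, y - x \rangle + \tfrac{\mu}{2} \|y - x\|_2^2, \\
&\text{(location-scale family)} &&
  z = \mathcal{T}_{\lambda}(u) = C u + m, \quad
  u = (u_1, \ldots, u_d), \;\; u_i \overset{\text{i.i.d.}}{\sim} \varphi, \\
&\text{(full-rank)} &&
  C \in \mathbb{R}^{d \times d} \text{ dense, e.g. lower triangular}, \\
&\text{(mean-field)} &&
  C = \mathrm{diag}(\sigma_1, \ldots, \sigma_d), \\
&\text{(reparametrization gradient)} &&
  \widehat{\nabla}_{\lambda} \, \mathbb{E}_{z \sim q_{\lambda}} f(z)
  = \nabla_{\lambda} f\big(\mathcal{T}_{\lambda}(u)\big), \quad
  u \sim \varphi^{\otimes d}.
\end{align*}
```

Under these forms a mean-field family has $2d$ parameters (location and per-coordinate scale), whereas a full-rank family has on the order of $d^2$.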