Wasserstein Gradient Flow over Variational Parameter Space for Variational Inference

Dai Hai Nguyen, Tetsuya Sakurai, Hiroshi Mamitsuka

TL;DR

This paper addresses variational inference by shifting the optimization domain from the latent variables to the variational-parameter space and solving the resulting problem with Wasserstein gradient flows (WGF). After showing that black-box VI (BBVI) and natural-gradient VI (NGVI) are special cases of WGF, the authors introduce GFlowVI and NGFlowVI. Both represent the variational posterior as a mixture whose component parameters are updated via preconditioned gradient flows and whose weights are updated via mirror descent (MD). They establish continuous-time descent properties and provide discrete-time, particle-based algorithms, including a simple MD-based mechanism for adapting component weights and a practical fix for negative Hessians based on constrained mirror maps. Empirical results on synthetic distributions and Bayesian neural networks show faster convergence, better approximation of multimodal posteriors, and favorable computational scaling compared with kernel-based methods, highlighting the method's flexibility and potential for handling complex variational families.
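The mixture update described above can be illustrated with a simplified sketch. It is not the authors' implementation. It assumes the reformulated objective is $\mathcal{L}(\rho)=\mathrm{KL}(q_\rho\,\|\,\pi)$, with $q_\rho$ a weighted mixture of 1D Gaussian components $q(z|\lambda_k)=N(\mu_k,\sigma_k^2)$. It also uses an identity preconditioner in place of the paper's preconditioned flows, estimates gradients with the reparameterization trick, and updates the weights by mirror descent with the entropic mirror map (an exponentiated-gradient step). The names `gmm_logpdf_and_score` and `gflow_step`, the step sizes, and the bimodal target are hypothetical choices made for illustration.

```python
import numpy as np


def gmm_logpdf_and_score(z, weights, means, stds):
    """Log-density log p(z) and score d/dz log p(z) of a 1D Gaussian mixture."""
    z = np.asarray(z, dtype=float)[:, None]                   # (N, 1)
    log_comp = (np.log(weights) - 0.5 * np.log(2 * np.pi) - np.log(stds)
                - 0.5 * ((z - means) / stds) ** 2)            # (N, J)
    logpdf = np.logaddexp.reduce(log_comp, axis=1)            # (N,)
    resp = np.exp(log_comp - logpdf[:, None])                 # responsibilities
    score = np.sum(resp * (-(z - means) / stds ** 2), axis=1)
    return logpdf, score


def gflow_step(mu, log_sig, w, target, rng, eta_x=0.05, eta_w=0.1, S=200):
    """One simplified GFlowVI-style update (hypothetical; identity preconditioner).

    mu, log_sig, w: arrays of shape (K,) describing the K mixture components.
    target: (weights, means, stds) of a hypothetical 1D Gaussian-mixture target.
    Returns updated (mu, log_sig, w) and an MC estimate of KL(q_rho || pi).
    """
    sig = np.exp(log_sig)
    K = mu.size
    eps = rng.standard_normal((K, S))
    z = mu[:, None] + sig[:, None] * eps                      # z ~ q(. | lambda_k)
    log_q, score_q = gmm_logpdf_and_score(z.ravel(), w, mu, sig)
    log_p, score_p = gmm_logpdf_and_score(z.ravel(), *target)
    h = (log_q - log_p).reshape(K, S)                         # log q_rho(z) - log pi(z)
    dh = (score_q - score_p).reshape(K, S)                    # derivative of h in z

    first_var = h.mean(axis=1)      # first variation at each lambda_k, up to a constant
    kl_est = np.dot(w, first_var)   # stratified MC estimate of KL(q_rho || pi)

    # Particle (Wasserstein) step: move each lambda_k along minus the gradient of
    # the first variation, via the reparameterization z = mu + sig * eps.
    grad_mu = dh.mean(axis=1)
    grad_log_sig = (dh * sig[:, None] * eps).mean(axis=1)
    mu_new = mu - eta_x * grad_mu
    log_sig_new = log_sig - eta_x * grad_log_sig

    # Mirror-descent step on the simplex (entropic mirror map).
    logits = np.log(w) - eta_w * first_var
    w_new = np.exp(logits - logits.max())
    w_new /= w_new.sum()
    return mu_new, log_sig_new, w_new, kl_est


# Usage: five components fitted to a hypothetical bimodal target.
rng = np.random.default_rng(0)
target = (np.array([0.5, 0.5]), np.array([-2.0, 2.0]), np.array([0.5, 0.5]))
mu, log_sig, w = np.linspace(-1.0, 1.0, 5), np.zeros(5), np.full(5, 0.2)
kls = []
for _ in range(300):
    mu, log_sig, w, kl = gflow_step(mu, log_sig, w, target, rng)
    kls.append(kl)
```

Because the constant in the first variation is shared by all components, it cancels in the weight normalization, so dropping it does not change the mirror-descent step.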

Abstract

Variational inference (VI) can be cast as an optimization problem in which the variational parameters are tuned to align a variational distribution closely with the true posterior. This optimization can be carried out with vanilla gradient descent, as in black-box VI, or with natural-gradient descent, as in natural-gradient VI. In this work, we reframe VI as the optimization of an objective over probability distributions defined on a variational parameter space, and we propose Wasserstein gradient descent to solve this optimization problem. Notably, black-box VI and natural-gradient VI can both be reinterpreted as specific instances of the proposed Wasserstein gradient descent. To improve the efficiency of optimization, we develop practical methods for numerically solving the discrete gradient flows. We validate the effectiveness of the proposed methods through empirical experiments on a synthetic dataset, supplemented by theoretical analyses.
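To make the special-case claim concrete, suppose (as an illustrative assumption) that $\rho$ is a single point mass at $\lambda$. The objective then reduces to $\mathrm{KL}(q(\cdot|\lambda)\,\|\,\pi)$. A Wasserstein gradient step on the lone particle becomes an ordinary gradient step on $\lambda$ (black-box VI), and preconditioning that step with the inverse Fisher information gives a natural-gradient step (natural-gradient VI). The sketch below compares the two on a hypothetical 1D Gaussian target, where all expectations are available in closed form. The NGVI step uses a common precision-form update for Gaussian families and is not taken from the paper; the function names are hypothetical.

```python
import numpy as np


def kl_gauss(mu, sig, m, s):
    """Closed-form KL(N(mu, sig^2) || N(m, s^2)) for 1D Gaussians."""
    return np.log(s / sig) + (sig ** 2 + (mu - m) ** 2) / (2 * s ** 2) - 0.5


def bbvi_step(mu, log_sig, m, s, lr=0.1):
    """Vanilla gradient step on (mu, log sigma) of KL(q || pi), pi = N(m, s^2)."""
    sig = np.exp(log_sig)
    grad_mu = (mu - m) / s ** 2
    grad_log_sig = sig ** 2 / s ** 2 - 1.0
    return mu - lr * grad_mu, log_sig - lr * grad_log_sig


def ngvi_step(mu, prec, m, s, beta=0.1):
    """Natural-gradient step for q = N(mu, 1/prec), written in precision form.

    Uses E_q[-d^2 log pi] and E_q[d log pi]; both are exact here because pi is Gaussian.
    """
    exp_neg_hess = 1.0 / s ** 2           # E_q[-d^2/dz^2 log pi(z)]
    exp_grad = -(mu - m) / s ** 2         # E_q[d/dz log pi(z)]
    prec_new = (1 - beta) * prec + beta * exp_neg_hess
    mu_new = mu + beta * exp_grad / prec_new
    return mu_new, prec_new


m, s = 1.0, 1.5                           # hypothetical target pi = N(m, s^2)
mu_b, log_sig_b = 3.0, np.log(2.0)        # BBVI state
mu_n, prec_n = 3.0, 1.0 / 2.0 ** 2        # NGVI state, same initial q = N(3, 2^2)
for _ in range(500):
    mu_b, log_sig_b = bbvi_step(mu_b, log_sig_b, m, s)
    mu_n, prec_n = ngvi_step(mu_n, prec_n, m, s)

print("BBVI KL:", kl_gauss(mu_b, np.exp(log_sig_b), m, s))
print("NGVI KL:", kl_gauss(mu_n, 1.0 / np.sqrt(prec_n), m, s))
```

Both updates drive $q$ to $N(m, s^2)$ here. The difference lies only in the geometry of the step: BBVI moves in Euclidean $(\mu, \log\sigma)$ coordinates, whereas the natural-gradient step rescales by the current precision.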

Paper Structure (21 sections, 10 theorems, 79 equations, 3 figures, 8 tables)

Key Result

Theorem 1

(First variation of $\mathcal{L}(\rho)$.) The first variation of the reformulated VI objective $\mathcal{L}(\rho)$ admits a Monte Carlo approximation using samples $\mathbold{\lambda}_{k}\sim \rho$, $k=1,\dots,K$, and $\textbf{z}_{i}\sim q(\cdot|\mathbold{\lambda})$, $i=1,\dots,S$.
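Theorem 1's Monte Carlo recipe can be sketched under an assumption not stated in this summary: that $\mathcal{L}(\rho)=\mathrm{KL}(q_\rho\,\|\,\pi)$ with $q_\rho(z)=\int q(z|\lambda)\,\rho(\lambda)\,d\lambda$. In that case the first variation is $\mathbb{E}_{q(z|\lambda)}[\log q_\rho(z)-\log\pi(z)]$ up to an additive constant. The estimate replaces $q_\rho$ with the average of $q(\cdot|\lambda_k)$ over $\lambda_k\sim\rho$, and the outer expectation with an average over $z_i\sim q(\cdot|\lambda)$. The sketch uses 1D Gaussian components; the function names are hypothetical.

```python
import numpy as np


def log_normal(z, mu, sig):
    """Elementwise log N(z; mu, sig^2)."""
    return -0.5 * np.log(2 * np.pi) - np.log(sig) - 0.5 * ((z - mu) / sig) ** 2


def first_variation_mc(lam, lam_samples, log_target, S=1000, rng=None):
    """MC estimate of the first variation at lam, up to an additive constant.

    Assumes L(rho) = KL(q_rho || pi) and 1D Gaussian components
    q(z | lam) = N(mu, exp(log_sig)^2) with lam = (mu, log_sig).
    lam_samples: array of shape (K, 2) holding lam_k ~ rho.
    """
    rng = np.random.default_rng() if rng is None else rng
    mu, log_sig = lam
    z = mu + np.exp(log_sig) * rng.standard_normal(S)              # z_i ~ q(. | lam)
    mus, sigs = lam_samples[:, 0], np.exp(lam_samples[:, 1])
    log_comp = log_normal(z[:, None], mus[None, :], sigs[None, :])  # (S, K)
    # log q_rho(z_i) ~= log((1/K) * sum_k q(z_i | lam_k)), computed stably
    log_q_rho = np.logaddexp.reduce(log_comp, axis=1) - np.log(len(lam_samples))
    return np.mean(log_q_rho - log_target(z))


def log_pi(z):
    # hypothetical target pi = N(0, 1)
    return log_normal(z, 0.0, 1.0)


rng = np.random.default_rng(1)
lam = np.array([1.0, 0.0])                 # q(. | lam) = N(1, 1)
lam_samples = np.tile(lam, (10, 1))        # rho concentrated at lam (K = 10 copies)
est = first_variation_mc(lam, lam_samples, log_pi, S=100_000, rng=rng)
```

Because $\rho$ is a point mass in this check, the estimate should be close to $\mathrm{KL}(N(1,1)\,\|\,N(0,1)) = 0.5$.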

Figures (3)

  • Figure 1: Experimental results on the synthetic dataset: (a) estimated KL divergence (log scale) between the target $\pi$ and the approximate density $q$ over 1,000 iterations for the four update methods with $K=10$; (b) performance of NGFlowVI and GFlowVI for $K$ = 1, 3 and 5; (c) visualizations of 1,000 samples from $q$ produced by the four update methods.
  • Figure 2: Average test negative log-likelihood of Bayesian neural networks (BNNs) on (a) 'Australia scale', and average test mean squared error of BNNs on (b) 'Boston' and (c) 'Concrete', over 1,000 iterations. SVGD uses 100 particles, while the other methods approximate the BNN weight posteriors with a Gaussian mixture ($K=10$). Parameters are updated using WVI-10, NGVI-10, GFlowVI-10 and NGFlowVI-10. Results are averaged over 20 runs of 20 data splits.
  • Figure 3: Experimental results on two synthetic datasets, visualizing 1,000 samples from the variational distribution $q$ produced by four methods (NGFlowVI, GFlowVI, NGVI and WVI) with $K=10$ particles. The target distributions are a banana-shaped distribution (first row) and an X-shaped distribution (second row).
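The KL curves and sample plots in these figures could be produced by sampling from the mixture $q$ and forming a Monte Carlo KL estimate. The captions specify neither the estimator nor the KL direction. The sketch below therefore assumes $\mathrm{KL}(q\,\|\,\pi)$ and a 2D diagonal-covariance Gaussian mixture for $q$, and it uses a hypothetical standard-normal target only as a sanity check; the banana- and X-shaped targets are not reproduced. All function names are hypothetical.

```python
import numpy as np


def sample_mixture(weights, means, stds, n, rng):
    """Draw n samples from a diagonal Gaussian mixture; means, stds have shape (K, D)."""
    comps = rng.choice(len(weights), size=n, p=weights)   # component index per sample
    eps = rng.standard_normal((n, means.shape[1]))
    return means[comps] + stds[comps] * eps


def mixture_logpdf(x, weights, means, stds):
    """log q(x) for the same diagonal Gaussian mixture; x has shape (n, D)."""
    diff = (x[:, None, :] - means[None, :, :]) / stds[None, :, :]   # (n, K, D)
    log_comp = (np.log(weights)[None, :]
                - np.sum(np.log(stds), axis=1)[None, :]
                - 0.5 * means.shape[1] * np.log(2 * np.pi)
                - 0.5 * np.sum(diff ** 2, axis=2))                  # (n, K)
    return np.logaddexp.reduce(log_comp, axis=1)


def kl_estimate(weights, means, stds, log_target, n=1000, rng=None):
    """MC estimate of KL(q || pi) from n samples x ~ q; also returns the samples."""
    rng = np.random.default_rng() if rng is None else rng
    x = sample_mixture(weights, means, stds, n, rng)
    kl = np.mean(mixture_logpdf(x, weights, means, stds) - log_target(x))
    return kl, x


def log_std_normal_2d(x):
    # hypothetical target pi = N(0, I_2), used only to check the estimator
    return -np.log(2 * np.pi) - 0.5 * np.sum(x ** 2, axis=1)


rng = np.random.default_rng(0)
K = 10
kl, samples = kl_estimate(np.full(K, 1.0 / K), np.zeros((K, 2)), np.ones((K, 2)),
                          log_std_normal_2d, n=1000, rng=rng)
```

With every component equal to the target, $\log q(x) = \log\pi(x)$ pointwise, so the estimate is zero up to rounding. For a real target, `samples` is the point cloud one would scatter-plot.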

Theorems & Definitions (16)

  • Theorem 1
  • Proposition 2
  • Corollary
  • Proposition 3
  • Corollary
  • Theorem 4
  • Proposition 5
  • Proof
  • Proof
  • Proof
  • ...and 6 more