Variational inference via radial transport

Luca Ghafourpour, Sinho Chewi, Alessio Figalli, Aram-Alexandre Pooladian

TL;DR

radVI is a cheap, effective add-on to many existing VI schemes, such as Gaussian (mean-field) VI and the Laplace approximation, and comes with theoretical convergence guarantees owing to recent developments in optimization over the Wasserstein space.

Abstract

In variational inference (VI), the practitioner approximates a high-dimensional distribution $\pi$ with a simple surrogate one, often a (product) Gaussian distribution. However, in many cases of practical interest, Gaussian distributions might not capture the correct radial profile of $\pi$, resulting in poor coverage. In this work, we approach the VI problem from the perspective of optimizing over these radial profiles. Our algorithm, radVI, is a cheap, effective add-on to many existing VI schemes, such as Gaussian (mean-field) VI and the Laplace approximation. We provide theoretical convergence guarantees for our algorithm, owing to recent developments in optimization over the Wasserstein space (the space of probability distributions endowed with the Wasserstein distance) and new regularity properties of radial transport maps in the style of Caffarelli (2000).
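
A radial map changes only how far each point lies from the center, never its direction, so composing it with a Gaussian base changes only the base's radial profile. The Python sketch below illustrates this under stated assumptions: the names `radial_map`, `sample_with_radial_add_on` and `heavier_profile` are hypothetical, the profile is an arbitrary increasing function chosen for illustration (not one learned by radVI), and the whitening step assumes a diagonal (mean-field) Gaussian fit with mean `mean` and scale `scale`, such as one produced by Gaussian VI or a Laplace approximation. It is not the paper's implementation.

```python
import math
import random


def radial_map(z, profile):
    """Hypothetical radial map T(z) = profile(|z|) * z / |z|.

    Only the length of z changes; its direction is kept. Assumes
    profile(0) == 0 so that T(0) = 0.
    """
    r = math.sqrt(sum(zi * zi for zi in z))
    if r == 0.0:
        return [0.0] * len(z)
    s = profile(r) / r
    return [s * zi for zi in z]


def sample_with_radial_add_on(mean, scale, profile, rng):
    """Draw x = mean + scale * T(z) with z ~ N(0, I_d).

    Assumption for illustration: `mean` and the diagonal `scale` come from an
    existing mean-field Gaussian fit, and T acts in whitened coordinates z.
    """
    z = [rng.gauss(0.0, 1.0) for _ in mean]
    t = radial_map(z, profile)
    return [m + s * ti for m, s, ti in zip(mean, scale, t)]


def identity_profile(r):
    # Leaves every radius unchanged, so the Gaussian base is recovered.
    return r


def heavier_profile(r):
    # Hypothetical nondecreasing, nonnegative profile that pushes mass outward.
    return r * (1.0 + 0.1 * r)


def mean_norm(samples):
    return sum(math.sqrt(sum(v * v for v in x)) for x in samples) / len(samples)


if __name__ == "__main__":
    d, n = 50, 2000
    rng = random.Random(0)
    mean, scale = [0.0] * d, [1.0] * d
    base = [sample_with_radial_add_on(mean, scale, identity_profile, rng) for _ in range(n)]
    fat = [sample_with_radial_add_on(mean, scale, heavier_profile, rng) for _ in range(n)]
    print("mean |x|, Gaussian base:", mean_norm(base))  # roughly sqrt(d), about 7
    print("mean |x|, radial add-on:", mean_norm(fat))
```

With the identity profile the sampler reduces to the Gaussian base, whose norms concentrate near $\sqrt{d}$. Because `heavier_profile` is nondecreasing and nonnegative, the map is the gradient of the convex function $x \mapsto F(\|x\|)$ with $F'$ equal to the profile, which is the form Brenier's theorem requires of an optimal transport map.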

Paper Structure (47 sections, 14 theorems, 145 equations, 8 figures, 3 tables, 2 algorithms)

Key Result

Theorem 2.1

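Suppose $\mu \in \mathcal{P}_{2,\rm{ac}}(\mathbb{R}^d)$. Then the optimal transport problem from $\mu$ to the target measure has a unique minimizer $T^\star$, which pushes $\mu$ forward to the target measure, and $T^\star = \nabla \varphi^\star$ for some convex function $\varphi^\star$.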
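
In one dimension the convex potential in Brenier's theorem has a nondecreasing derivative, so the optimal map for the squared cost is the monotone rearrangement $T = F_\nu^{-1} \circ F_\mu$, where $F_\mu$ and $F_\nu$ are the cumulative distribution functions of the source and of a target $\nu$ (notation introduced here). For rotationally symmetric source and target, the optimal map is radial and reduces to this one-dimensional problem on the radii $\|x\|$. The Python sketch below computes an empirical version by sorting; the function names are hypothetical and this is an illustration of Brenier's theorem, not the paper's algorithm.

```python
import random


def monotone_rearrangement(source, target):
    """Empirical 1-D optimal transport map for the squared cost.

    Sends the k-th smallest source point to the k-th smallest target point,
    a sample version of T = F_target^{-1} o F_source. Returns the image of
    each source point, in the source's original order.
    """
    if len(source) != len(target):
        raise ValueError("source and target must have the same size")
    order = sorted(range(len(source)), key=lambda i: source[i])
    sorted_target = sorted(target)
    image = [0.0] * len(source)
    for rank, i in enumerate(order):
        image[i] = sorted_target[rank]
    return image


def quadratic_cost(source, image):
    """Average squared displacement (1/n) * sum (x_i - T(x_i))^2."""
    return sum((x - y) ** 2 for x, y in zip(source, image)) / len(source)


if __name__ == "__main__":
    rng = random.Random(0)
    n = 1000
    source = [rng.gauss(0.0, 1.0) for _ in range(n)]
    target = [rng.expovariate(1.0) for _ in range(n)]
    image = monotone_rearrangement(source, target)
    shuffled = target[:]
    rng.shuffle(shuffled)
    print("monotone pairing cost:", quadratic_cost(source, image))
    print("random pairing cost:  ", quadratic_cost(source, shuffled))
```

By the rearrangement inequality, pairing sorted samples never costs more than any other pairing of the same two samples, which the random shuffle in the demo illustrates.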

Figures (8)

  • Figure 1: Convergence of radVI for various target distributions. See Table 1 for final-iterate comparisons between Gaussian VI (GVI) and the Laplace approximation (LA).
  • Figure 2: Comparing learned radial profiles of radVI versus other approximation methods for learning the Student-$t$ distribution in the isotropic (top) and anisotropic (bottom) cases.
  • Figure 3: In the case where $\pi$ is an isotropic Gaussian with $d=50$, we verify that radVI is robust to the choice of $\alpha$.
  • Figure 4: Top: Comparing whitening methods for learning the anisotropic logistic distribution, with and without radVI. Bottom: Visual comparison of true target samples, samples generated by LA, and samples generated by ours (LA+radVI).
  • Figure 5: Comparing learned radial profiles of radVI versus other approximation methods for the isotropic Student-$t$ distribution.
  • ...and 3 more figures

Theorems & Definitions (23)

  • Theorem 2.1: Brenier's theorem
  • Remark 3.1
  • Proposition 3.2
  • Proposition 3.3: Stationary condition
  • Proposition 3.4: Regularity of radial minimizer
  • Theorem 3.5: Regularity of the optimal radial map
  • Theorem 4.1: Universal approximation
  • Remark 4.2
  • Theorem 4.3: Convergence of radVI
  • Proposition 4.4
  • ...and 13 more