Variational inference via radial transport
Luca Ghafourpour, Sinho Chewi, Alessio Figalli, Aram-Alexandre Pooladian
TL;DR
radVI is a cheap, effective add-on to many existing VI schemes, such as Gaussian (mean-field) VI and the Laplace approximation, and comes with theoretical convergence guarantees owing to recent developments in optimization over the Wasserstein space.
Abstract
In variational inference (VI), the practitioner approximates a high-dimensional distribution $\pi$ with a simple surrogate, often a (product) Gaussian distribution. However, in many cases of practical interest, Gaussian distributions might not capture the correct radial profile of $\pi$, resulting in poor coverage. In this work, we approach the VI problem from the perspective of optimizing over these radial profiles. Our algorithm radVI is a cheap, effective add-on to many existing VI schemes, such as Gaussian (mean-field) VI and the Laplace approximation. We provide theoretical convergence guarantees for our algorithm, owing to recent developments in optimization over the Wasserstein space (the space of probability distributions endowed with the Wasserstein distance) and new regularity properties of radial transport maps in the style of Caffarelli (2000).
