Transport meets Variational Inference: Controlled Monte Carlo Diffusions

Francisco Vargas, Shreyas Padhy, Denis Blessing, Nikolas Nüsken

TL;DR

This work bridges optimal transport and variational inference by formulating sampling in terms of a divergence on path space between forward and backward diffusion measures, $D(\overrightarrow{{\mathbb{P}}}^{\mu,a}\,\|\,\overleftarrow{{\mathbb{P}}}^{\nu,b})$, and introduces the Controlled Monte Carlo Diffusion (CMCD) sampler for Bayesian computation. CMCD fixes a path of target densities $\pi_t$ and learns a time-dependent control $\nabla\phi_t$ so that the forward diffusion ${\mathrm d}{\bm Y}_t=(\sigma^2\nabla\ln\pi_t({\bm Y}_t)+\nabla\phi_t({\bm Y}_t))\,dt+\sigma\sqrt{2}\,\overrightarrow{d}{\bm W}_t$ interpolates from $\pi_0$ to $\pi_T$, while a controlled Crooks identity yields unbiased estimates of the normalising constant. The authors connect EM and IPF through Schrödinger bridges, prove that the CMCD objective yields a unique optimal drift, and link the dynamic Schrödinger problem to entropy-regularised transport. Empirically, CMCD achieves state-of-the-art performance on sampling and normalising-constant estimation across multiple benchmarks, demonstrating the practical benefit of unifying VI and OT through end-to-end diffusion training. The framework offers a principled, end-to-end approach to entropy-regularised transport and diffusion-based inference, with potential extensions to adaptive annealing strategies and novel divergences.
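To make the normalising-constant estimator concrete, here is a minimal NumPy sketch of an Euler–Maruyama discretisation of the forward diffusion above, paired with a matching backward kernel whose log-density ratios accumulate into importance weights. It is a sketch under stated assumptions, not the authors' implementation: the 1D start $\pi_0 = \mathcal{N}(0,1)$, the unnormalised target $\gamma_T(x) = e^{-(x-3)^2/2}$ (so $\ln Z = \tfrac12\ln 2\pi$), the geometric annealing path, the step size and the names `estimate_log_Z`, `control` and `grad_log_pi` are all hypothetical. The learned control $\nabla\phi_t$ is replaced by zero, which reduces the scheme to uncontrolled Langevin annealing. The backward kernel uses the reverse-time drift $\sigma^2\nabla\ln\pi_t - \nabla\phi_t$, which is what Nelson's relation gives for noise of scale $\sigma\sqrt{2}$ when $\pi_t$ is the marginal.

```python
import numpy as np

SIGMA = 1.0  # noise scale sigma in sigma*sqrt(2) dW (hypothetical value)

def log_normal(x, mean, var):
    """Elementwise log-density of N(mean, var)."""
    return -0.5 * (x - mean) ** 2 / var - 0.5 * np.log(2.0 * np.pi * var)

def grad_log_pi(x, beta):
    """Gradient of ln pi_beta for a hypothetical geometric path between pi_0 = N(0, 1)
    and gamma_T(x) = exp(-(x - 3)^2 / 2); for this choice pi_beta = N(3 * beta, 1)."""
    return -x + 3.0 * beta

def log_gamma_T(x):
    """Unnormalised target; its true normaliser is Z = sqrt(2 * pi)."""
    return -0.5 * (x - 3.0) ** 2

def control(x, beta):
    """Hypothetical stand-in for the learned control grad phi_t (a network in CMCD); zero here."""
    return np.zeros_like(x)

def estimate_log_Z(n_particles=10_000, n_steps=500, delta=0.1, seed=0):
    """Return (log of the importance-sampling estimate of Z, mean log-weight = ELBO)."""
    rng = np.random.default_rng(seed)
    var = 2.0 * SIGMA**2 * delta           # variance of sigma*sqrt(2) dW over one step
    x = rng.standard_normal(n_particles)   # x_0 ~ pi_0 = N(0, 1)
    log_w = -log_normal(x, 0.0, 1.0)       # -ln pi_0(x_0)
    for k in range(n_steps):
        beta = (k + 1) / n_steps
        # Forward kernel: drift sigma^2 grad ln pi_t + grad phi_t at the current state.
        fwd_mean = x + delta * (SIGMA**2 * grad_log_pi(x, beta) + control(x, beta))
        x_new = fwd_mean + np.sqrt(var) * rng.standard_normal(n_particles)
        # Backward kernel: reverse-time drift sigma^2 grad ln pi_t - grad phi_t at x_new.
        bwd_mean = x_new + delta * (SIGMA**2 * grad_log_pi(x_new, beta) - control(x_new, beta))
        log_w += log_normal(x, bwd_mean, var) - log_normal(x_new, fwd_mean, var)
        x = x_new
    log_w += log_gamma_T(x)                # + ln gamma_T(x_K)
    m = log_w.max()                        # numerically stable log-mean-exp
    return m + np.log(np.mean(np.exp(log_w - m))), log_w.mean()

log_Z_hat, elbo = estimate_log_Z()
print(log_Z_hat, elbo, 0.5 * np.log(2 * np.pi))
```

Because every backward kernel integrates to one in its first argument, the average weight is an unbiased estimate of $Z$ for any choice of control; training $\nabla\phi_t$ aims to reduce its variance. The mean log-weight is an ELBO and, by Jensen's inequality, never exceeds the log of the average weight.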

Abstract

Connecting optimal transport and variational inference, we present a principled and systematic framework for sampling and generative modelling centred around divergences on path space. Our work culminates in the development of the Controlled Monte Carlo Diffusion sampler (CMCD) for Bayesian computation, a score-based annealing technique that crucially adapts both forward and backward dynamics in a diffusion model. On the way, we clarify the relationship between the EM-algorithm and iterative proportional fitting (IPF) for Schrödinger bridges, and derive a regularised objective that bypasses the iterative bottleneck of standard IPF updates. Finally, we show that CMCD has a strong foundation in the Jarzynski and Crooks identities from statistical physics, and that it convincingly outperforms competing approaches across a wide array of experiments.

Paper Structure (50 sections, 11 theorems, 116 equations, 6 figures, 10 tables, 4 algorithms)

Key Result

Proposition 2.1

For $\mu$ and $a$ of sufficient regularity, denote the time-marginals of the corresponding path measure by $\overrightarrow{{\mathbb{P}}}^{\mu,a}_t =:\rho^{\mu,a}_t$. Then $\overrightarrow{{\mathbb{P}}}^{\mu,a} = \overleftarrow{{\mathbb{P}}}^{\nu,b}$ if and only if
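Nelson's relation is commonly stated under the following conventions, which are an assumption here: forward dynamics ${\mathrm d}{\bm Y}_t = a_t({\bm Y}_t)\,dt + \sigma\,\overrightarrow{d}{\bm W}_t$ with ${\bm Y}_0\sim\mu$, and backward dynamics ${\mathrm d}{\bm Y}_t = b_t({\bm Y}_t)\,dt + \sigma\,\overleftarrow{d}{\bm W}_t$ with ${\bm Y}_T\sim\nu$. Under them the two path measures coincide exactly when $\nu = \rho^{\mu,a}_T$ and $a_t - b_t = \sigma^2\nabla\ln\rho^{\mu,a}_t$ for all $t$. The NumPy sketch below checks the "if" direction numerically for a hypothetical 1D Ornstein–Uhlenbeck forward drift $a_t(x) = -x$ with $\sigma = \sqrt{2}$ and $\mu = \mathcal{N}(2, 0.5^2)$, whose marginals $\rho_t$ are Gaussian in closed form: starting from $\nu = \rho_T$ and integrating the backward SDE with $b_t = a_t - \sigma^2\nabla\ln\rho_t$ should recover $\mu$ at time 0. All names and parameter values are illustrative.

```python
import numpy as np

# Hypothetical 1D example: forward OU drift a_t(x) = -x, noise sigma = sqrt(2), mu = N(M0, S0^2).
SIGMA = np.sqrt(2.0)
M0, S0, T = 2.0, 0.5, 2.0

def marginal(t):
    """Mean and variance of the Gaussian marginal rho_t of the forward OU process."""
    mean = M0 * np.exp(-t)
    var = S0**2 * np.exp(-2.0 * t) + (1.0 - np.exp(-2.0 * t))
    return mean, var

def forward_drift(x, t):
    return -x

def backward_drift(x, t):
    """Nelson's relation: b_t = a_t - sigma^2 * grad ln rho_t."""
    mean, var = marginal(t)
    score = -(x - mean) / var  # grad ln rho_t for a Gaussian marginal
    return forward_drift(x, t) - SIGMA**2 * score

def run_backward(drift, n_particles=20_000, n_steps=2_000, seed=0):
    """Integrate dY = drift dt + sigma <-dW from Y_T ~ rho_T down to time 0."""
    rng = np.random.default_rng(seed)
    h = T / n_steps
    mean_T, var_T = marginal(T)
    y = mean_T + np.sqrt(var_T) * rng.standard_normal(n_particles)  # Y_T ~ nu = rho_T
    for k in range(n_steps, 0, -1):
        t = k * h
        # Reverse-time Euler-Maruyama step from t to t - h.
        y = y - h * drift(y, t) + SIGMA * np.sqrt(h) * rng.standard_normal(n_particles)
    return y

y0 = run_backward(backward_drift)
print(y0.mean(), y0.std())  # compare with M0 = 2.0 and S0 = 0.5
```

As a contrast, calling `run_backward(forward_drift)` (dropping the score term, so $b_t = a_t$) leaves the mean roughly in place in this example but inflates the standard deviation by more than an order of magnitude, so the recovered samples no longer follow $\mu$.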

Figures (6)

  • Figure 1: Panes (a) and (b) report ELBOs across methods and targets, following the experimental setup in geffner2023langevin; the (OD) and (UD) columns group overdamped and underdamped methods separately. Pane (c) reports IS $\ln Z$ estimates and, where available, sample quality measured with eOT. Higher ELBO and $\ln Z$ indicate better estimates; lower ${\mathcal{W}}^\gamma_2$ indicates better sample quality.
  • Figure 2: Architecture from geffner2023langevin used across experiments for our CMCD drift network. Softplus activations are used.
  • Figure 3: (left) 2000 samples drawn from the CMCD algorithm trained with the default loss function, and (right) 2000 samples drawn from the algorithm trained with the log-variance divergence-based loss. The default loss misses many modes of the target distribution, whereas the log-variance loss misses none. We report final results after sweeping over $\Delta_{t_k}$ and learning rates for both methods, choosing the configuration with the lowest training loss. Concurrent work by richter2023improved explores the log-variance divergence in more detail and proposes a similar general framework for diffusion-based sampling.
  • Figure 4: Training-loss curves for the log-variance loss and the default loss at different values of $\Delta_{t_k}$. We find that a small value, $\Delta_{t_k}=0.1$, is needed for the default loss to reach a low training loss, whereas the log-variance loss is much more robust to the choice of $\Delta_{t_k}$. The x-axis indexes evaluations performed every 150 training steps.
  • Figure 5: (a) our proposed regularised objective; (b) $\lambda$ set to 0 but with an EM-motivated initialisation; (c) $\lambda$ set to 0 with random initialisation of the forward drift; (d) for reference, DNF with $f_t = 0$ (uninformative Schrödinger prior).
  • ...and 1 more figure
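Figures 3 and 4 compare the default loss with a log-variance loss. The sketch below shows how both can be computed from the same batch of per-trajectory log-weights $\ln\frac{d\overleftarrow{\mathbb{P}}}{d\overrightarrow{\mathbb{P}}}$, such as those accumulated by an Euler–Maruyama sampler. It assumes the default loss is the KL objective $\mathrm{KL}(\overrightarrow{\mathbb{P}}\,\|\,\overleftarrow{\mathbb{P}})$, i.e. the negative mean log-weight (negative ELBO); the helper name `path_space_losses` and the synthetic log-weights are hypothetical. The log-variance divergence can also be evaluated under a reference path measure different from the one being trained, a detail this toy function does not model.

```python
import numpy as np

def path_space_losses(log_w):
    """Hypothetical helper. log_w[i] is ln(dP_backward / dP_forward) along trajectory i,
    with the backward measure built from the unnormalised target.

    Returns (kl_loss, log_var_loss):
      kl_loss      = -mean(log_w), the negative ELBO (KL(P_forward || P_backward) minus ln Z);
      log_var_loss = empirical variance of log_w, zero iff all log-weights agree.
    """
    log_w = np.asarray(log_w, dtype=float)
    return -log_w.mean(), log_w.var()

# Synthetic log-weights, for illustration only.
rng = np.random.default_rng(0)
log_w = rng.normal(loc=-1.0, scale=0.5, size=1_000)
# Normalising the target subtracts ln Z from every log-weight: kl_loss shifts by +ln Z,
# while log_var_loss is unchanged.
print(path_space_losses(log_w))
print(path_space_losses(log_w - 0.9))
```

The design point is that the variance ignores any constant offset, so the log-variance loss never needs the unknown $\ln Z$; the KL loss moves by that constant, although its gradient with respect to the control does not.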

Theorems & Definitions (36)

  • Proposition 2.1: Nelson's relation
  • Remark 1
  • Proposition 2.2: forward-backward Radon-Nikodym derivatives
  • proof
  • Remark 2: Role of the reference process
  • Remark 3: Discretisation and conversion formulae
  • Proposition 3.1: EM $\iff$ IPF
  • Proposition 3.2: Existence and uniqueness
  • Proposition 3.3: Controlled Crooks' fluctuation theorem and Jarzynski's equality
  • Proposition 3.4: infinitesimal Schrödinger problems
  • ...and 26 more