Weighted Stochastic Differential Equation to Implement Wasserstein-Fisher-Rao Gradient Flow

Herlock Rahimi

TL;DR

The paper develops a framework to implement Wasserstein–Fisher–Rao gradient flows for diffusion-based samplers by introducing explicit Fisher–Rao corrections via weighted SDEs. It systematically builds the WFR geometry by detailing Wasserstein, Fisher–Rao, and their hybrid HK/WFR structures, and shows how PDEs for probability densities can be realized as SDEs with reweighting. The approach preserves diffusion-level stochastic calculus while enabling transport and mass-change evolution, potentially mitigating metastability in multimodal targets. The work also highlights the equivalence to jump-process formulations and provides a Feynman–Kac representation for weighting, establishing a theoretical foundation for future algorithmic development. It lays groundwork for analyzing curvature and spectral properties to quantify convergence improvements.
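The statement that a PDE with a mass-change term can be realized as an SDE with reweighting can be made concrete through the standard Feynman–Kac identity. The sketch below uses generic placeholder symbols ($b$ for the drift, $\sigma$ for a constant noise level, $c$ for the reaction rate, $w_t$ for the particle weight); these are assumptions for illustration and are not taken from the paper's notation or from its Theorem 8.

```latex
% Weighted SDE: a diffusing particle X_t carrying a weight w_t
\begin{aligned}
  dX_t &= b(X_t)\,dt + \sigma\,dW_t, \\
  dw_t &= c(X_t)\,w_t\,dt, \qquad w_0 = 1
  \;\;\Longrightarrow\;\;
  w_t = \exp\!\Big(\int_0^t c(X_s)\,ds\Big).
\end{aligned}

% The weighted law \mu_t(\varphi) = \mathbb{E}[\,w_t\,\varphi(X_t)\,]
% solves a Fokker--Planck equation with a reaction term:
\partial_t \mu_t
  = -\nabla\!\cdot(b\,\mu_t) + \tfrac{\sigma^2}{2}\,\Delta \mu_t + c\,\mu_t .

% Normalizing, \rho_t = \mu_t / \mu_t(1), keeps total mass equal to one:
\partial_t \rho_t
  = -\nabla\!\cdot(b\,\rho_t) + \tfrac{\sigma^2}{2}\,\Delta \rho_t
    + \Big(c - \int c\,d\rho_t\Big)\rho_t .
```

The first two terms on the right-hand side are the transport and diffusion part; the last term is the vertical (reaction) part, which changes mass locally without moving particles.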

Abstract

Score-based diffusion models currently constitute the state of the art in continuous generative modeling. These methods are typically formulated via overdamped or underdamped Ornstein--Uhlenbeck-type stochastic differential equations, in which sampling is driven by a combination of deterministic drift and Brownian diffusion, resulting in continuous particle trajectories in the ambient space. While such dynamics enjoy exponential convergence guarantees for strongly log-concave target distributions, it is well known that their mixing rates deteriorate exponentially in the presence of nonconvex or multimodal landscapes, such as double-well potentials. Since many practical generative modeling tasks involve highly non-log-concave target distributions, considerable recent effort has been devoted to developing sampling schemes that improve exploration beyond classical diffusion dynamics. A promising line of work leverages tools from information geometry to augment diffusion-based samplers with controlled mass reweighting mechanisms. This perspective leads naturally to Wasserstein--Fisher--Rao (WFR) geometries, which couple transport in the sample space with vertical (reaction) dynamics on the space of probability measures. In this work, we formulate such reweighting mechanisms through the introduction of explicit correction terms and show how they can be implemented via weighted stochastic differential equations using the Feynman--Kac representation. Our study provides a preliminary but rigorous investigation of WFR-based sampling dynamics, and aims to clarify their geometric and operator-theoretic structure as a foundation for future theoretical and algorithmic developments.
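The mechanism described in the abstract (diffusion for transport, weights for mass change) can be sketched as a weighted particle system. The following is a minimal illustration under stated assumptions, not the paper's sampling scheme: the double-well potential $V(x) = (x^2-1)^2$, the function names `grad_V`, `normalize`, and `weighted_langevin`, the user-supplied `rate` function, and the resampling threshold `ess_frac` are all hypothetical choices made here.

```python
import math
import random


def grad_V(x):
    """Gradient of the (assumed) double-well potential V(x) = (x^2 - 1)^2."""
    return 4.0 * x * (x * x - 1.0)


def normalize(logw):
    """Convert log-weights to weights summing to one (max-shifted for stability)."""
    m = max(logw)
    w = [math.exp(l - m) for l in logw]
    s = sum(w)
    return [wi / s for wi in w]


def weighted_langevin(x0, rate, n_steps=1000, dt=1e-3, ess_frac=0.5, seed=0):
    """Euler-Maruyama Langevin particles with Feynman-Kac style weights.

    `rate` is a hypothetical reaction rate c(x); the specific Fisher-Rao
    correction used in the paper is NOT reproduced here.
    """
    rng = random.Random(seed)
    x = list(x0)
    n = len(x)
    logw = [0.0] * n
    noise = math.sqrt(2.0 * dt)
    for _ in range(n_steps):
        # Transport step: overdamped Langevin move for every particle.
        x = [xi - grad_V(xi) * dt + noise * rng.gauss(0.0, 1.0) for xi in x]
        # Reaction step: d log w = (c(x) - weighted mean of c) dt,
        # which keeps the weighted measure normalized to first order in dt.
        w = normalize(logw)
        c = [rate(xi) for xi in x]
        c_bar = sum(wi * ci for wi, ci in zip(w, c))
        logw = [l + dt * (ci - c_bar) for l, ci in zip(logw, c)]
        # Resample (multinomial) when the effective sample size gets too small.
        w = normalize(logw)
        ess = 1.0 / sum(wi * wi for wi in w)
        if ess < ess_frac * n:
            x = rng.choices(x, weights=w, k=n)
            logw = [0.0] * n
    return x, normalize(logw)


if __name__ == "__main__":
    # All particles start in the left well; the toy rate c(x) = x is an
    # arbitrary illustration that favors particles with larger x.
    xs, ws = weighted_langevin([-1.0] * 500, rate=lambda x: x)
    right_mass = sum(wi for xi, wi in zip(xs, ws) if xi > 0.0)
    print("weighted mass in the right well:", right_mass)
```

With a constant `rate`, the centered term vanishes and the weights stay uniform, so the scheme reduces to plain Langevin dynamics. A WFR-type sampler would instead need a rate that depends on the current density relative to the target, which this sketch leaves to the caller.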

Paper Structure

This paper contains 34 sections, 10 theorems, 258 equations, 1 figure, 5 tables.

Key Result

Theorem 8

The weighted SDE and the sampling scheme introduced above realize the generalized PDE, i.e.

Figures (1)

  • Figure 1: Geodesic structure and median trajectories across different geometries in the $(\mu,\sigma)$ parameter space. Each panel depicts a triangle formed by three distributions $(p,u,v)$, where the edges connecting $p$ to $u$ and $p$ to $v$ are constructed using a fixed geometry indexed by $i\in\{1,2,3,4\}$, while the edge connecting $u$ and $v$ is generated using a (possibly different) geometry indexed by $j\in\{1,2,3,4\}$. Type 1 corresponds to Wasserstein geodesics (blue), Type 2 to linear mixture geodesics (red), Type 3 to exponential geodesics (green), and Type 4 to Fisher--Rao geodesics (purple). Gray segments indicate collapse trajectories toward the reference distribution $p$ under the geometry indexed by $i$, and the black curve denotes the induced median trajectory between $u$ and $v$, obtained by projecting the $u$--$v$ geodesic through $p$ according to the pair $(i,j)$. Dashed curves represent the direct $u$--$v$ geodesic in geometry $j$, while solid colored curves illustrate the lifted or corrected paths resulting from the interaction between the two geometric structures. The collection of panels highlights how mismatches between transport and information-geometric structures modify both geodesic shapes and the resulting median trajectories.

Theorems & Definitions (29)

  • Definition 1: Metric derivative
  • Definition 2: Continuity equation
  • Definition 3: Wasserstein tangent space and metric
  • Definition 4: Wasserstein gradient
  • Definition 5: Fisher--Rao metric on $\mathcal{M}_+(\Omega)$
  • Definition 6: Wasserstein--Fisher--Rao tangent space
  • Definition 7: WFR gradient
  • Theorem 8
  • Definition 9: Statistical manifold
  • Lemma 1: Diffusion can be written as drift
  • ...and 19 more