Weighted Stochastic Differential Equation to Implement Wasserstein-Fisher-Rao Gradient Flow
Herlock Rahimi
TL;DR
The paper develops a framework for implementing Wasserstein–Fisher–Rao (WFR) gradient flows in diffusion-based samplers by adding explicit Fisher–Rao correction terms, realized through weighted SDEs. It builds the WFR geometry systematically, treating the Wasserstein and Fisher–Rao geometries in turn before combining them into the hybrid Hellinger–Kantorovich (HK)/WFR structure, and shows how PDEs for probability densities can be realized as SDEs with reweighting. The approach retains diffusion-level stochastic calculus while allowing both transport and changes of mass, which may mitigate metastability for multimodal targets. The work also highlights the equivalence to jump-process formulations and gives a Feynman–Kac representation of the weights, establishing a theoretical foundation for future algorithmic development. Finally, it lays groundwork for analyzing curvature and spectral properties in order to quantify convergence improvements.
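As a hedged illustration of how a density PDE is realized as a weighted SDE, the block below writes one standard form of the WFR gradient flow of KL(ρ‖π) with π ∝ e^{−V}, together with its Feynman–Kac representation. The choice of KL as the energy and the reaction weight α are assumptions for illustration; the paper's explicit correction terms may take a different form.

```latex
\partial_t \rho_t
  = \underbrace{\Delta\rho_t + \nabla\cdot(\rho_t\nabla V)}_{\text{Wasserstein (transport)}}
  \;-\; \underbrace{\alpha\,\rho_t\Big(\log\frac{\rho_t}{\pi} - \int \rho_t\log\frac{\rho_t}{\pi}\,dx\Big)}_{\text{Fisher--Rao (reaction)}}

dX_t = -\nabla V(X_t)\,dt + \sqrt{2}\,dW_t,\qquad X_0\sim\rho_0,\qquad
c_t(x) = -\alpha\log\frac{\rho_t(x)}{\pi(x)}

\int f\,d\rho_t
  = \frac{\mathbb{E}\big[f(X_t)\,e^{\int_0^t c_s(X_s)\,ds}\big]}
         {\mathbb{E}\big[e^{\int_0^t c_s(X_s)\,ds}\big]}
```

The transport part uses the identity ∇·(ρ∇log(ρ/π)) = Δρ + ∇·(ρ∇V). The unnormalized measure μ_t(f) = E[f(X_t) exp(∫_0^t c_s(X_s) ds)] solves ∂_t μ = Δμ + ∇·(μ∇V) + c_t μ, and normalizing it produces exactly the mean-subtracted reaction term, so only the relative weights of the particles matter (any additive constant in log π cancels).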
Abstract
Score-based diffusion models currently constitute the state of the art in continuous generative modeling. These methods are typically formulated via overdamped or underdamped Ornstein–Uhlenbeck-type stochastic differential equations (SDEs), in which sampling is driven by a combination of deterministic drift and Brownian diffusion, producing continuous particle trajectories in the ambient space. While such dynamics enjoy exponential convergence guarantees for strongly log-concave target distributions, it is well known that their mixing times can grow exponentially with the height of the energy barriers separating modes in nonconvex or multimodal landscapes, such as double-well potentials. Since many practical generative modeling tasks involve highly non-log-concave target distributions, considerable recent effort has gone into developing sampling schemes that explore the target more effectively than classical diffusion dynamics. A promising line of work uses tools from information geometry to augment diffusion-based samplers with controlled mass-reweighting mechanisms. This perspective leads naturally to Wasserstein–Fisher–Rao (WFR) geometries, which couple transport in the sample space with vertical (reaction) dynamics on the space of probability measures. In this work, we formulate such reweighting mechanisms by introducing explicit correction terms and show how they can be implemented via weighted stochastic differential equations through the Feynman–Kac representation. Our study provides a preliminary but rigorous investigation of WFR-based sampling dynamics and aims to clarify their geometric and operator-theoretic structure as a foundation for future theoretical and algorithmic developments.
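To make the weighted-SDE idea concrete on a double-well potential, here is a minimal sketch, not the paper's algorithm. It assumes a 1D double well V(x) = h(x² − 1)², the KL energy, and a reaction weight α. Particles take Euler–Maruyama steps of dX = −V′(X) dt + √2 dW (the transport part), and each log-weight accumulates c(x) dt with c(x) = −α(log ρ_t(x) + V(x)) (the Fisher–Rao part, up to an additive constant removed by self-normalization). Because ρ_t is unknown, it is replaced by a weighted Gaussian kernel density estimate of the particles; this estimator, the bandwidth, the initialization, and all function names are hypothetical choices for illustration.

```python
import numpy as np


def potential(x, h=6.0):
    # Hypothetical double well V(x) = h (x^2 - 1)^2; target pi(x) ∝ exp(-V(x)).
    return h * (x**2 - 1.0) ** 2


def grad_potential(x, h=6.0):
    return 4.0 * h * x * (x**2 - 1.0)


def log_weighted_kde(x, w, bandwidth):
    # Log of a weighted Gaussian KDE of the particle measure, evaluated at the particles.
    z = (x[:, None] - x[None, :]) / bandwidth
    kernel = np.exp(-0.5 * z**2) / (bandwidth * np.sqrt(2.0 * np.pi))
    return np.log(kernel @ w)


def normalize(log_w):
    w = np.exp(log_w - log_w.max())  # subtract the max to avoid overflow
    return w / w.sum()


def simulate(alpha=1.0, n=400, t_end=3.0, dt=5e-3, bandwidth=0.1, seed=0):
    """Weighted Langevin particles; alpha=0 recovers plain (unweighted) Langevin."""
    rng = np.random.default_rng(seed)
    n_right = n // 5  # imbalanced start: 80% in the left well, 20% in the right well
    x = np.concatenate([rng.normal(-1.0, 0.15, n - n_right),
                        rng.normal(1.0, 0.15, n_right)])
    log_w = np.zeros(n)
    for _ in range(int(round(t_end / dt))):
        if alpha != 0.0:
            w = normalize(log_w)
            # Fisher-Rao rate c = -alpha * (log rho_t - log pi), up to a constant
            c = -alpha * (log_weighted_kde(x, w, bandwidth) + potential(x))
            log_w += c * dt
        # Euler-Maruyama step for dX = -V'(X) dt + sqrt(2) dW
        x = x - grad_potential(x) * dt + np.sqrt(2.0 * dt) * rng.standard_normal(n)
    return x, normalize(log_w)


def right_well_mass(x, w):
    # Total normalized weight carried by particles in the right well (x > 0).
    return float(np.sum(w[x > 0.0]))
```

Comparing `right_well_mass(*simulate(alpha=0.0))` with `right_well_mass(*simulate(alpha=1.0))` illustrates the intended effect: with barrier height h = 6, plain Langevin rarely crosses over this horizon, so its right-well mass stays near the initial 0.2, whereas the reweighted run should move it toward the target value 1/2. Reweighting only redistributes mass among existing particles, so it cannot populate a mode that no particle has reached; practical schemes would typically also resample when the effective sample size drops, which this sketch omits.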
