Wasserstein Flow Matching: Generative modeling over families of distributions
Doron Haviv, Aram-Alexandre Pooladian, Dana Pe'er, Brandon Amos
TL;DR
This work extends Flow Matching to the space of distributions, enabling generative modeling between distributions rather than between points. Using the $W_2$ (2-Wasserstein) geometry, Wasserstein Flow Matching (WFM) handles Gaussian distributions through the Bures–Wasserstein submanifold, where optimal transport (OT) maps are available in closed form, and general distributions through entropic OT on point-clouds. In both cases, transformers learn conditional vector fields that follow Wasserstein geodesics. Training is simulation-free, and the method is demonstrated on generating 2D and 3D shapes and high-dimensional cellular microenvironments from spatial transcriptomics data. The approach broadens generative modeling to domains where each sample is itself a distribution, with potential impact in computational biology and graphics, though synthetic data still requires careful interpretation in scientific contexts.
Abstract
Generative modeling typically concerns transporting a single source distribution to a target distribution via simple probability flows. However, in fields like computer graphics and single-cell genomics, samples themselves can be viewed as distributions, and standard flow matching ignores their inherent geometry. We propose Wasserstein flow matching (WFM), which lifts flow matching onto families of distributions using the Wasserstein geometry. Notably, WFM is the first algorithm capable of generating distributions in high dimensions, whether represented analytically (as Gaussians) or empirically (as point-clouds). Our theoretical analysis establishes that Wasserstein geodesics constitute proper conditional flows over the space of distributions, yielding a valid flow matching (FM) objective. Our algorithm leverages optimal transport theory and the attention mechanism, demonstrating versatility across computational regimes: exploiting closed-form optimal transport paths for Gaussian families, while using entropic estimates on point-clouds for general distributions. WFM successfully generates both 2D and 3D shapes and high-dimensional cellular microenvironments from spatial transcriptomics data. Code is available at https://github.com/DoronHav/WassersteinFlowMatching.
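To make the "closed-form optimal transport paths for Gaussian families" concrete, the following is a minimal NumPy sketch of the standard Bures–Wasserstein formulas: the linear OT map between two Gaussians, the resulting $W_2$ geodesic, and the $W_2$ distance. It is an illustration under stated assumptions, not the authors' implementation; the function names (`sqrtm_spd`, `gaussian_ot_map`, `bw_geodesic`, `bw_distance`) are hypothetical, and covariances are assumed symmetric positive definite.

```python
import numpy as np

def sqrtm_spd(S):
    """Square root of a symmetric PSD matrix via eigendecomposition."""
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T

def gaussian_ot_map(S0, S1):
    """Linear part A of the OT map x -> m1 + A (x - m0) from N(m0, S0) to N(m1, S1).

    A = S0^{-1/2} (S0^{1/2} S1 S0^{1/2})^{1/2} S0^{-1/2}; assumes S0 is invertible.
    """
    r = sqrtm_spd(S0)
    r_inv = np.linalg.inv(r)
    return r_inv @ sqrtm_spd(r @ S1 @ r) @ r_inv

def bw_geodesic(m0, S0, m1, S1, t):
    """Mean and covariance of the W2 geodesic between two Gaussians at time t in [0, 1]."""
    A = gaussian_ot_map(S0, S1)
    M = (1.0 - t) * np.eye(len(m0)) + t * A  # displacement interpolation of the map
    return (1.0 - t) * m0 + t * m1, M @ S0 @ M

def bw_distance(m0, S0, m1, S1):
    """W2 distance between N(m0, S0) and N(m1, S1)."""
    r = sqrtm_spd(S0)
    cross = sqrtm_spd(r @ S1 @ r)
    sq = np.sum((m0 - m1) ** 2) + np.trace(S0 + S1 - 2.0 * cross)
    return np.sqrt(max(sq, 0.0))  # clamp tiny negative values from round-off

# Example inputs (arbitrary values chosen for illustration).
m0, S0 = np.zeros(2), np.array([[2.0, 0.5], [0.5, 1.0]])
m1, S1 = np.array([1.0, -2.0]), np.array([[1.0, -0.3], [-0.3, 3.0]])
m_half, S_half = bw_geodesic(m0, S0, m1, S1, 0.5)
```

The geodesic returns the source Gaussian at `t = 0` and the target at `t = 1`, and it has constant speed: the distance from the source to the interpolant at time `t` is `t` times the full distance. Point-cloud inputs do not admit this closed form, which is where the entropic estimates mentioned in the abstract come in.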
