Wasserstein Flow Matching: Generative modeling over families of distributions

Doron Haviv, Aram-Alexandre Pooladian, Dana Pe'er, Brandon Amos

TL;DR

This work extends Flow Matching to operate over the space of distributions, enabling generative modeling between distributions rather than between points. By leveraging the $W_2$ (Wasserstein) geometry, WFM handles two cases: Gaussian distributions, via the Bures–Wasserstein submanifold with closed-form OT maps, and general distributions, via entropic OT on point-clouds. In both cases, transformers learn conditional vector fields that respect Wasserstein geodesics. Training is simulation-free, and the method shows performance comparable to the state of the art in generating both 2D/3D shapes and high-dimensional cellular microenvironments from spatial transcriptomics data. The approach broadens the scope of generative modeling to domains where samples are themselves distributions, with potential impact in computational biology and graphics, while synthetic data still requires careful interpretation in scientific contexts.

Abstract

Generative modeling typically concerns transporting a single source distribution to a target distribution via simple probability flows. However, in fields like computer graphics and single-cell genomics, samples themselves can be viewed as distributions, where standard flow matching ignores their inherent geometry. We propose Wasserstein flow matching (WFM), which lifts flow matching onto families of distributions using the Wasserstein geometry. Notably, WFM is the first algorithm capable of generating distributions in high dimensions, whether represented analytically (as Gaussians) or empirically (as point-clouds). Our theoretical analysis establishes that Wasserstein geodesics constitute proper conditional flows over the space of distributions, making for a valid FM objective. Our algorithm leverages optimal transport theory and the attention mechanism, demonstrating versatility across computational regimes: exploiting closed-form optimal transport paths for Gaussian families, while using entropic estimates on point-clouds for general distributions. WFM successfully generates both 2D & 3D shapes and high-dimensional cellular microenvironments from spatial transcriptomics data. Code is available at https://github.com/DoronHav/WassersteinFlowMatching .
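
To make the "entropic estimates on point-clouds" mentioned in the abstract concrete, the following is a minimal NumPy sketch, not the authors' implementation. It assumes uniform weights on both clouds, a squared-Euclidean cost, a log-domain Sinkhorn loop, and a barycentric projection of the resulting plan to obtain an approximate transport map. The function names (`sinkhorn_plan`, `entropic_interpolant`) and the values of `eps` and `n_iters` are hypothetical choices for illustration only.

```python
import numpy as np

def _logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along one axis.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def sinkhorn_plan(x, y, eps=0.1, n_iters=500):
    """Entropic OT plan between uniform point clouds x (n, d) and y (m, d)."""
    n, m = x.shape[0], y.shape[0]
    # Pairwise cost 0.5 * ||x_i - y_j||^2.
    cost = 0.5 * np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    log_a = np.full(n, -np.log(n))
    log_b = np.full(m, -np.log(m))
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iters):
        # Alternate dual updates that match the row, then column, marginals.
        f = -eps * _logsumexp((g[None, :] - cost) / eps + log_b[None, :], axis=1)
        g = -eps * _logsumexp((f[:, None] - cost) / eps + log_a[:, None], axis=0)
    log_plan = (f[:, None] + g[None, :] - cost) / eps + log_a[:, None] + log_b[None, :]
    return np.exp(log_plan)

def entropic_interpolant(x, y, t, eps=0.1, n_iters=500):
    """Move each x_i a fraction t of the way to its barycentric target."""
    plan = sinkhorn_plan(x, y, eps, n_iters)
    targets = (plan @ y) / plan.sum(axis=1, keepdims=True)
    return (1.0 - t) * x + t * targets

# Two clouds with different sizes; the second is shifted to the right.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 2))
y = rng.normal(size=(48, 2)) + np.array([3.0, 0.0])
x_half = entropic_interpolant(x, y, t=0.5)
```

Because the plan is regularized, the barycentric targets are blurred averages of points in `y`, so the interpolant only approximates the Wasserstein geodesic; smaller `eps` tightens the approximation at the cost of slower convergence.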

Paper Structure

This paper contains 38 sections, 7 theorems, 49 equations, 12 figures, 4 tables, 3 algorithms.

Key Result

Proposition 3.1

Conditional probability paths $\mathfrak p_t(\cdot|\nu)$ generated by conditional vector fields of the form $v_t$ (recall (eq:vectorfield)) satisfy $\mathfrak p_1(\cdot|\nu) = \nu$.
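
The endpoint property in Proposition 3.1 has a simple numerical analogue in the Gaussian case. The sketch below does not reproduce the paper's $v_t$ from (eq:vectorfield), which is not shown on this page; instead it uses the standard closed-form Bures–Wasserstein geodesic between two Gaussians and checks that the path ends at the target. The helper names (`sqrtm_psd`, `bw_ot_matrix`, `bw_geodesic`) and the example covariances are assumptions made for illustration.

```python
import numpy as np

def sqrtm_psd(mat):
    """Square root of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T

def bw_ot_matrix(cov0, cov1):
    """Linear part A of the OT map x -> mean1 + A (x - mean0) between Gaussians."""
    root0 = sqrtm_psd(cov0)
    inv_root0 = np.linalg.inv(root0)
    middle = sqrtm_psd(root0 @ cov1 @ root0)
    return inv_root0 @ middle @ inv_root0

def bw_geodesic(mean0, cov0, mean1, cov1, t):
    """Mean and covariance at time t along the W_2 geodesic between two Gaussians."""
    a = bw_ot_matrix(cov0, cov1)
    m = (1.0 - t) * np.eye(len(mean0)) + t * a  # symmetric, since A is symmetric
    mean_t = (1.0 - t) * mean0 + t * mean1
    cov_t = m @ cov0 @ m
    return mean_t, cov_t

mean0, cov0 = np.zeros(2), np.array([[2.0, 0.5], [0.5, 1.0]])
mean1, cov1 = np.array([1.0, -2.0]), np.array([[1.0, -0.3], [-0.3, 3.0]])

# At t = 1 the path should land exactly on the target Gaussian.
mean_end, cov_end = bw_geodesic(mean0, cov0, mean1, cov1, t=1.0)
print(np.allclose(mean_end, mean1), np.allclose(cov_end, cov1))
```

The covariance at time $t$ is the pushforward of `cov0` under the linear map $(1-t)I + tA$, which is why the endpoint reduces to $A \Sigma_0 A = \Sigma_1$.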

Figures (12)

  • Figure 1: WFM learns flows between distributions over distributions, where distributions are either analytically represented (Gaussians) or empirically observed through point-clouds.
  • Figure 2: In the presence of sufficiently many samples, all methods generate Gaussians along the whole spiral, and our Riemannian BW-FM algorithm produces the most consistent samples. Other methods produce Gaussians with degenerate covariance, as they do not model the geometry of the data. When there are only a few examples, BW-FM accurately reconstructs the training data; see (fig:bw_spirals_8).
  • Figure 3: Left. Synthesized samples from WFM trained on the cars, planes or chairs datasets. Right. Examples generated conditionally from the same initial noise via a WFM model trained on the complete 40-class ModelNet dataset.
  • Figure 4: Generated distributions from MNIST and Letters datasets using WFM. Each distribution is realized as a point-cloud with naturally varying sample sizes ($n$), demonstrating WFM's unique ability to learn from and generate distributions without requiring fixed-size realizations.
  • Figure 5: WFM enables high-dimensional generation of cellular microenvironments. (A) Wasserstein 2D embedding of observed Sst-neuron niches in motor cortex, colored by subtype. (B) WFM-generated niches faithfully reproduce the tissue landscape.
  • ...and 7 more figures

Theorems & Definitions (11)

  • Proposition 3.1
  • Proposition 1.1
  • proof
  • Theorem 1.2
  • Corollary 1.3
  • Lemma 2.1
  • proof
  • Lemma 2.2
  • proof
  • Corollary 2.3
  • ...and 1 more