Neural solver for Wasserstein Geodesics and optimal transport dynamics

Hailiang Liu, Yan-Han Chen

TL;DR

This work introduces a sample-based neural solver for computing the Wasserstein geodesic between a source and a target distribution, along with the associated velocity field, using deep neural networks to approximate the relevant functions.

Abstract

In recent years, the machine learning community has increasingly embraced the optimal transport (OT) framework for modeling distributional relationships. In this work, we introduce a sample-based neural solver for computing the Wasserstein geodesic between a source and a target distribution, along with the associated velocity field. Building on the dynamical formulation of the OT problem, we recast the constrained optimization as a minimax problem, using deep neural networks to approximate the relevant functions. This approach not only provides the Wasserstein geodesic but also recovers the OT map, enabling direct sampling from the target distribution. By estimating the OT map, we obtain velocity estimates along particle trajectories, which in turn allow us to learn the full velocity field. The framework is flexible and readily extends to general cost functions, including the commonly used quadratic cost. We demonstrate the effectiveness of our method through experiments on both synthetic and real datasets.
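To make the link between an OT map, the geodesic, and trajectory velocities concrete, the following is a minimal sketch and not the neural solver described above. It assumes the quadratic cost and two one-dimensional Gaussians, a case where the OT map is known in closed form, and uses the standard displacement interpolation $x_t = (1-t)\,x + t\,T(x)$. All function names and parameter values are hypothetical and chosen only for illustration.

```python
import random
import statistics


def ot_map_gaussian_1d(x, m_a, s_a, m_b, s_b):
    """Closed-form quadratic-cost OT map from N(m_a, s_a^2) to N(m_b, s_b^2).

    Assumption: 1D Gaussians only; the paper instead learns the map with networks.
    """
    return m_b + (s_b / s_a) * (x - m_a)


def geodesic_point(x, t, T):
    """Displacement interpolation: position at time t of the particle starting at x."""
    return (1.0 - t) * x + t * T(x)


def trajectory_velocity(x, T):
    """Velocity along the straight-line trajectory; constant in t for quadratic cost."""
    return T(x) - x


# Illustrative (hypothetical) source and target parameters.
m_a, s_a, m_b, s_b = 0.0, 1.0, 3.0, 0.5


def T(x):
    return ot_map_gaussian_1d(x, m_a, s_a, m_b, s_b)


rng = random.Random(0)
source = [rng.gauss(m_a, s_a) for _ in range(20000)]

# Push the source samples along the geodesic to t = 0.5 and t = 1.
midpoint = [geodesic_point(x, 0.5, T) for x in source]
target = [geodesic_point(x, 1.0, T) for x in source]

# Velocity estimates attached to each particle trajectory.
velocities = [trajectory_velocity(x, T) for x in source]

mid_mean = statistics.fmean(midpoint)
mid_std = statistics.pstdev(midpoint)
target_mean = statistics.fmean(target)
target_std = statistics.pstdev(target)
```

In this Gaussian case the intermediate distribution is again Gaussian, with mean $(1-t)m_a + t\,m_b$ and standard deviation $(1-t)s_a + t\,s_b$, so the empirical statistics of `midpoint` and `target` can be checked against those values. A regression on the pairs (position at time $t$, velocity) would be one way to turn such per-particle velocities into a full velocity field.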

Paper Structure (22 sections, 1 theorem, 44 equations, 20 figures, 2 algorithms)

Key Result

Theorem 2.2

Suppose the min-max problem (learn_F) admits a unique solution $(F^*,\phi^*)$. Then $F^*$ is also a solution to problem (Problem_W), which is equivalent to the original problem (dOMT_wk).

Figures (20)

  • Figure 1: Phase 1 of Synthetic-1: Samples of the source distribution $\rho_a$ (open circle at the bottom left corner) are transported to the upper right corner. Each gray line illustrates the learned trajectory $G_\theta(t;\cdot)_\#\rho_a$. The contour shows the true log density of $\rho_b$.
  • Figure 2: Phase 2 of Synthetic-1: learned velocity fields at selected time points. The arrows illustrate the direction and magnitude of the learned velocity.
  • Figure 3: Phase 1 of Synthetic-2: Samples of the source distribution $\rho_a$ (open circle at the center) are transported to the four corners. Each gray line illustrates the learned trajectory $G_\theta(t;\cdot)_\#\rho_a$. The contour shows the true log density of $\rho_b$.
  • Figure 4: Phase 2 of Synthetic-2: learned velocity fields at selected time points. The arrows illustrate the direction and magnitude of the learned velocity.
  • Figure 5: Phase 1 of Synthetic-3: 2-dimensional projection of the 10-dimensional distributions. Samples of the source distribution $\rho_a$ (open circle at the bottom left corner) are transported to the upper right corner. Each gray line illustrates the learned trajectory $G_\theta(t;\cdot)_\#\rho_a$. The contour shows the true log density of $\rho_b$.
  • ...and 15 more figures

Theorems & Definitions (4)

  • Remark 2.1
  • Theorem 2.2
  • Remark 2.3
  • Remark 3.1