GeONet: a neural operator for learning the Wasserstein geodesic

Andrew Gracyk, Xiaohui Chen

TL;DR

GeONet addresses the computational bottleneck of computing Wasserstein geodesics by learning the non-linear geodesic operator with a mesh-invariant neural operator. It recasts the Benamou–Brenier dynamic OT problem as a coupled primal–dual PDE system and solves it with an enhanced DeepONet architecture that maps the input pair $(\mu_0, \mu_1)$ to the geodesic $\{\mu_t\}_{t \in [0,1]}$, enabling real-time predictions and zero-shot super-resolution. The method achieves accuracy competitive with traditional OT solvers on Gaussian mixtures and encoded MNIST data while delivering substantial inference-time speedups, and it behaves robustly under distribution shift and in continuous-to-discrete settings. Limitations include scaling to higher dimensions and the need for carefully chosen input sampling; future work aims at multi-resolution training, theoretical generalization bounds, and extensions to related mean-field planning problems.
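The reference geodesics in the paper are computed with POT. For univariate measures such as the Gaussian mixtures mentioned above, a cheap independent reference also exists: in one dimension the monotone (sorted) coupling is optimal for the quadratic cost, so the geodesic between two equal-size samples is obtained by linearly interpolating the sorted samples. The sketch below is that textbook construction, not the paper's pipeline; the helper names and the mixture weights, means, and standard deviations are made-up illustrative values.

```python
import numpy as np

def sample_gmm(rng, weights, means, stds, n):
    """Draw n samples from a univariate Gaussian mixture."""
    comp = rng.choice(len(weights), size=n, p=weights)
    return rng.normal(np.asarray(means)[comp], np.asarray(stds)[comp])

def w2_1d(x0, x1):
    """W2 distance between two equal-size, equal-weight empirical measures on R."""
    return np.sqrt(np.mean((np.sort(x0) - np.sort(x1)) ** 2))

def w2_geodesic_1d(x0, x1, ts):
    """Displacement interpolation on R: the sorted (monotone) coupling is W2-optimal,
    so the particles of mu_t are (1 - t) * sorted(x0) + t * sorted(x1)."""
    a, b = np.sort(x0), np.sort(x1)
    return np.stack([(1.0 - t) * a + t * b for t in ts])

# Illustrative (hypothetical) endpoint mixtures.
rng = np.random.default_rng(0)
x0 = sample_gmm(rng, [0.5, 0.5], [-2.0, 1.0], [0.5, 0.3], 2000)
x1 = sample_gmm(rng, [0.3, 0.7], [0.0, 3.0], [0.4, 0.6], 2000)
ts = np.linspace(0.0, 1.0, 5)
path = w2_geodesic_1d(x0, x1, ts)  # shape (5, 2000): particles of mu_t at each t
```

Along this path $W_2(\mu_s, \mu_t) = |t - s|\, W_2(\mu_0, \mu_1)$ holds for the empirical measures, so `w2_1d(path[1], path[3])` is half of `w2_1d(x0, x1)`. A histogram of each row gives a density snapshot at one time, i.e. one vertical slice of a plot laid out like Figure 3.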

Abstract

Optimal transport (OT) offers a versatile framework to compare complex data distributions in a geometrically meaningful way. Traditional methods for computing the Wasserstein distance and geodesic between probability measures require mesh-specific domain discretization and suffer from the curse of dimensionality. We present GeONet, a mesh-invariant deep neural operator network that learns the non-linear mapping from the input pair of initial and terminal distributions to the Wasserstein geodesic connecting the two endpoint distributions. In the offline training stage, GeONet learns the saddle-point optimality conditions for the dynamic formulation of the OT problem in the primal and dual spaces, which are characterized by a coupled PDE system. The subsequent inference stage is instantaneous and can be deployed for real-time predictions in the online learning setting. We demonstrate that GeONet achieves testing accuracy comparable to standard OT solvers on simulation examples and the MNIST dataset while reducing the inference-stage computational cost by orders of magnitude.
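For orientation, here is one standard way to write the Benamou–Brenier problem and the first-order optimality system it leads to: a continuity equation for the primal variable $\mu$ and a Hamilton–Jacobi equation for the dual potential $\varphi$, with optimal velocity $v = \nabla \varphi$. This is the textbook form, not a transcription of the paper's equations; the paper's normalization, domain, and boundary conditions may differ.

$$\frac{1}{2} W_2^2(\mu_0, \mu_1) = \min_{(\mu_t, v_t)} \int_0^1 \int \frac{1}{2} \|v_t(x)\|^2 \, d\mu_t(x) \, dt \quad \text{s.t.} \quad \partial_t \mu_t + \nabla \cdot (\mu_t v_t) = 0, \quad \mu_{t=0} = \mu_0, \ \mu_{t=1} = \mu_1,$$

$$\partial_t \mu + \nabla \cdot (\mu \nabla \varphi) = 0, \qquad \partial_t \varphi + \frac{1}{2} \|\nabla \varphi\|^2 = 0.$$

The first equation (primal) transports the density and its solution $\mu_t$ is the geodesic; the second (dual) evolves the potential whose gradient is the velocity. Stationarity of the Lagrangian with multiplier $\varphi$ in $v$ gives $v = \nabla \varphi$, and stationarity in $\mu$ gives the Hamilton–Jacobi equation.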

Paper Structure (30 sections, 6 theorems, 47 equations, 10 figures, 7 tables, 1 algorithm)

Key Result

Theorem C.3

If $\omega : [0,1] \to X$ is Lipschitz continuous, then the metric derivative $|\omega'|(t)$ exists for almost every $t \in [0,1]$. In addition, for any $0 \leqslant t < s \leqslant 1$, we have

$$d(\omega(t), \omega(s)) \leqslant \int_t^s |\omega'|(r) \, dr.$$
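To make the length bound concrete, the sketch below takes $X = \mathbb{R}^2$ with the Euclidean distance (an illustrative choice, not an example from the paper), where the metric derivative of a smooth curve is its speed. It approximates $|\omega'|$ with a symmetric difference quotient and checks $d(\omega(t), \omega(s)) \leqslant \int_t^s |\omega'|(r)\,dr$ using the trapezoidal rule. The curve and the helper names are hypothetical.

```python
import numpy as np

def spiral(t):
    """Illustrative Lipschitz curve in R^2: omega(t) = t * (cos 4*pi*t, sin 4*pi*t)."""
    t = np.asarray(t, dtype=float)
    return np.stack([t * np.cos(4 * np.pi * t), t * np.sin(4 * np.pi * t)], axis=-1)

def metric_derivative(omega, t, h=1e-6):
    """Symmetric difference quotient approximating |omega'|(t) for the Euclidean metric."""
    return np.linalg.norm(omega(t + h) - omega(t - h), axis=-1) / (2 * h)

def check_length_bound(omega, t, s, n=10_001):
    """Return (bound holds, d(omega(t), omega(s)), trapezoidal estimate of int_t^s |omega'|)."""
    r = np.linspace(t, s, n)
    speed = metric_derivative(omega, r)
    length = np.sum(0.5 * (speed[1:] + speed[:-1]) * np.diff(r))
    dist = np.linalg.norm(omega(s) - omega(t))
    # Small tolerance absorbs discretization error in the equality case (straight lines).
    return dist <= length + 1e-8, dist, length
```

For the spiral on $[0, 1]$ the endpoint distance is $1$ while the integral is roughly $6.4$; for a straight segment such as $\omega(t) = (t, 2t)$ the two sides agree up to discretization error, which is the equality case of the bound.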

Figures (10)

  • Figure 1: A geodesic at different spatial resolutions. Low-resolution inputs can be adapted into high-resolution geodesics (i.e., super-resolution) with our output mesh-invariant GeONet method.
  • Figure 2: Architecture of GeONet, containing six neural networks, three each for the continuity and Hamilton-Jacobi (HJ) equations. We minimize the total loss, and the continuity solution yields the geodesic. The GeONet branch and trunk networks output vectors of dimension $p$, whose elements are multiplied together to produce the continuity and HJ solutions.
  • Figure 3: Four geodesics predicted by GeONet alongside reference geodesics computed by POT on test pairs of univariate Gaussian mixture distributions with $k_0 = k_1 = 6$. The reference serves as a close approximation to the true geodesic. The vertical axis is space and the horizontal axis is time.
  • Figure 4: We compare GeONet to alternative methods in a discrete setting, using POT as ground truth. Among the compared methods, only GeONet captures the geodesic behaviour as the points are transported.
  • Figure 5: Beginning from the top left and going clockwise, we display the initial conditions in the encoded space, the geodesics in the encoded space, and the decoded geodesics as $28 \times 28$ images. (a) and (b) correspond to two distinct pairings.
  • ...and 5 more figures
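Figure 2 describes, for each of the two PDEs, three networks whose $p$-dimensional outputs are multiplied together. One plausible reading, assumed here rather than taken from the paper, is a DeepONet-style product of two branch networks (encoding $\mu_0$ and $\mu_1$ sampled on a fixed sensor grid) and one trunk network (encoding the query point $(x, t)$): $u(x,t) \approx \sum_{k=1}^{p} b^0_k(\mu_0)\, b^1_k(\mu_1)\, \tau_k(x,t)$. The untrained NumPy sketch below shows only that forward pass; the layer sizes, the tanh activation, and the names `init_mlp`, `mlp`, and `geonet_like_forward` are hypothetical, and no PDE loss or training is included.

```python
import numpy as np

def init_mlp(rng, sizes):
    """Random (untrained) weights for an MLP with layer widths `sizes`."""
    return [
        (rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, n)), np.zeros(n))
        for m, n in zip(sizes[:-1], sizes[1:])
    ]

def mlp(params, x):
    """tanh hidden layers, linear output layer."""
    for W, b in params[:-1]:
        x = np.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b

def geonet_like_forward(nets, mu0, mu1, xt):
    """Assumed branch-branch-trunk product: sum_k b0_k(mu0) * b1_k(mu1) * tau_k(x, t).

    mu0, mu1: endpoint densities on a fixed sensor grid, shape (m,)
    xt:       query points (x, t), shape (q, 2) for 1D space; q is arbitrary
    returns:  one value per query point, shape (q,)
    """
    b0 = mlp(nets["branch0"], mu0)  # (p,)
    b1 = mlp(nets["branch1"], mu1)  # (p,)
    tau = mlp(nets["trunk"], xt)    # (q, p)
    return tau @ (b0 * b1)          # (q,)

# Usage: one triple of networks for a single PDE solution (e.g. the continuity one).
rng = np.random.default_rng(0)
m, p = 50, 32
nets = {
    "branch0": init_mlp(rng, [m, 64, p]),
    "branch1": init_mlp(rng, [m, 64, p]),
    "trunk": init_mlp(rng, [2, 64, p]),
}
grid = np.linspace(0.0, 1.0, m)
mu0 = np.exp(-((grid - 0.3) ** 2) / 0.01)
mu0 /= mu0.sum()
mu1 = np.exp(-((grid - 0.7) ** 2) / 0.01)
mu1 /= mu1.sum()
xs = np.linspace(0.0, 1.0, 101)                          # fine spatial grid
xt_fine = np.stack([xs, np.full_like(xs, 0.5)], axis=1)  # queries at t = 0.5
u_fine = geonet_like_forward(nets, mu0, mu1, xt_fine)
u_coarse = geonet_like_forward(nets, mu0, mu1, xt_fine[::10])  # 11-point subgrid
```

Because the trunk is evaluated pointwise, the same networks can be queried on a coarse or a fine $(x, t)$ grid and agree at shared points, which is what makes the output mesh-invariant in this sketch. A full model along these lines would instantiate two such triples, one for the continuity solution and one for the HJ potential.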

Theorems & Definitions (12)

  • Definition C.1: Absolutely continuous curve
  • Definition C.2: Metric derivative
  • Theorem C.3: Rademacher
  • Definition C.4: Curve length
  • Lemma C.5
  • Definition C.6: Length space and geodesic space
  • Definition C.7: Geodesic
  • Lemma C.8
  • Definition C.9: Wasserstein metric derivative
  • Theorem C.10
  • ...and 2 more
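Definitions C.7 and C.9 concern geodesics and the metric derivative in Wasserstein space. A standard fact from the theory of Wasserstein spaces (stated here independently; the statement of Theorem C.10 is not shown above) is that an absolutely continuous curve $\mu_t$ in $W_2$ on $\mathbb{R}^d$ admits a velocity field $v_t$ solving the continuity equation with $\|v_t\|_{L^2(\mu_t)} = |\mu'|(t)$ for almost every $t$. For univariate Gaussians everything is explicit: $W_2(N(m_0,\sigma_0^2), N(m_1,\sigma_1^2))^2 = (m_0 - m_1)^2 + (\sigma_0 - \sigma_1)^2$, the geodesic interpolates means and standard deviations linearly, and the driving velocity is affine. The sketch below (helper names hypothetical) estimates both sides numerically.

```python
import numpy as np

def w2_gauss(m0, s0, m1, s1):
    """Closed-form W2 between N(m0, s0^2) and N(m1, s1^2) on R."""
    return np.hypot(m0 - m1, s0 - s1)

def geodesic_params(m0, s0, m1, s1, t):
    """Mean and standard deviation of the W2 geodesic mu_t at time t."""
    return (1 - t) * m0 + t * m1, (1 - t) * s0 + t * s1

def metric_derivative_fd(m0, s0, m1, s1, t, h=1e-5):
    """Forward-difference estimate of |mu'|(t) = lim_{h->0} W2(mu_{t+h}, mu_t) / h."""
    a = geodesic_params(m0, s0, m1, s1, t)
    b = geodesic_params(m0, s0, m1, s1, t + h)
    return w2_gauss(*a, *b) / h

def velocity_l2_norm(m0, s0, m1, s1, t, n=200_000, seed=0):
    """Monte Carlo estimate of ||v_t||_{L^2(mu_t)} for the affine velocity field
    v_t(x) = (m1 - m0) + (s1 - s0) * (x - m_t) / s_t that moves mu_t along the geodesic."""
    mt, st = geodesic_params(m0, s0, m1, s1, t)
    x = np.random.default_rng(seed).normal(mt, st, size=n)
    v = (m1 - m0) + (s1 - s0) * (x - mt) / st
    return np.sqrt(np.mean(v ** 2))
```

With $(m_0, \sigma_0) = (0, 1)$ and $(m_1, \sigma_1) = (3, 2)$, both estimates stay close to $W_2(\mu_0, \mu_1) = \sqrt{10} \approx 3.162$ at every $t$, as expected for a constant-speed geodesic.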