GeONet: a neural operator for learning the Wasserstein geodesic
Andrew Gracyk, Xiaohui Chen
TL;DR
GeONet addresses the computational bottleneck of computing Wasserstein geodesics by learning the non-linear geodesic operator with a mesh-invariant neural operator. It recasts the Benamou–Brenier dynamic OT problem as a coupled primal–dual PDE system and solves it with an enhanced DeepONet architecture that maps the input pair $(\mu_0, \mu_1)$ to the geodesic $\{\mu_t\}_{t \in [0,1]}$, enabling real-time predictions and zero-shot super-resolution. The method achieves accuracy competitive with traditional OT solvers on Gaussian mixtures and MNIST-encoded data while delivering substantial inference-time speedups, and it behaves robustly under distribution shift and in continuous-to-discrete settings. Limitations include scaling to higher dimensions and the need for carefully chosen input sampling; future work targets multi-resolution training, theoretical generalization bounds, and extensions to related mean-field planning problems.
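For orientation, the Benamou–Brenier problem and the coupled primal–dual optimality system referred to above can be written in the following standard form. The velocity $v_t$ and dual potential $\phi_t$ are our notation and may not match the paper's symbols.

```latex
% Dynamic (Benamou--Brenier) formulation of the squared 2-Wasserstein distance
W_2^2(\mu_0,\mu_1) = \min_{(\mu_t,\, v_t)} \int_0^1 \!\! \int_{\mathbb{R}^d} \|v_t(x)\|^2 \, \mu_t(x)\, dx\, dt
\quad \text{s.t.} \quad
\partial_t \mu_t + \nabla \cdot (\mu_t v_t) = 0, \qquad
\mu_t\big|_{t=0} = \mu_0, \qquad \mu_t\big|_{t=1} = \mu_1 .

% At the saddle point, v_t = \nabla \phi_t and (\mu_t, \phi_t) solve the coupled system
\partial_t \mu_t + \nabla \cdot (\mu_t \nabla \phi_t) = 0
\qquad \text{(continuity equation, primal)}

\partial_t \phi_t + \tfrac{1}{2} \|\nabla \phi_t\|^2 = 0
\qquad \text{(Hamilton--Jacobi equation, dual)}
```

A network that outputs both $\mu_t$ and $\phi_t$ could be trained by penalizing the residuals of these two equations together with the endpoint constraints; that is one reading of "learning the optimality conditions", not a statement of the paper's exact loss.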
Abstract
Optimal transport (OT) offers a versatile framework to compare complex data distributions in a geometrically meaningful way. Traditional methods for computing the Wasserstein distance and geodesic between probability measures require mesh-specific domain discretization and suffer from the curse of dimensionality. We present GeONet, a mesh-invariant deep neural operator network that learns the non-linear mapping from an input pair of initial and terminal distributions to the Wasserstein geodesic connecting them. In the offline training stage, GeONet learns the saddle-point optimality conditions of the dynamic formulation of the OT problem in the primal and dual spaces, which are characterized by a coupled PDE system. The subsequent inference stage is instantaneous and can be deployed for real-time predictions in the online learning setting. We demonstrate that GeONet achieves testing accuracy comparable to standard OT solvers on simulation examples and the MNIST dataset while reducing inference-stage computational cost by orders of magnitude.
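The abstract describes an operator that, once trained offline, is evaluated instantly at arbitrary space-time points. The sketch below shows only the generic DeepONet structure this builds on, in 1D with NumPy: a branch network encodes $(\mu_0, \mu_1)$ sampled at fixed sensor points, a trunk network encodes a query $(x, t)$, and the two are combined by an inner product. The weights are random and untrained, and every name (`init_mlp`, `mlp`, `deeponet`, `gaussian_pdf`), size, and grid here is a hypothetical choice for illustration; this is not GeONet's enhanced architecture or its training procedure.

```python
import numpy as np


def init_mlp(sizes, rng):
    """Random (untrained) weights for a fully connected network."""
    return [(rng.normal(0.0, 1.0 / np.sqrt(m), size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def mlp(params, x):
    """tanh hidden layers followed by a linear output layer."""
    for W, b in params[:-1]:
        x = np.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def deeponet(branch, trunk, mu0, mu1, xt):
    """Evaluate a DeepONet-style operator G(mu0, mu1)(x, t).

    mu0, mu1: endpoint densities sampled at the same fixed sensor points, shape (m,)
    xt:       arbitrary query points (x, t), shape (q, 2)
    returns:  one predicted value per query point, shape (q,)
    """
    coeffs = mlp(branch, np.concatenate([mu0, mu1])[None, :])  # (1, p) from the branch net
    basis = mlp(trunk, xt)                                      # (q, p) from the trunk net
    return (basis @ coeffs.T).ravel()                           # inner product per query


def gaussian_pdf(x, mean, sd):
    return np.exp(-0.5 * ((x - mean) / sd) ** 2) / (sd * np.sqrt(2.0 * np.pi))


rng = np.random.default_rng(0)
m, p = 50, 32                       # sensor count and latent width (assumed)
sensors = np.linspace(-3.0, 3.0, m)
mu0 = gaussian_pdf(sensors, -1.0, 0.5)
mu1 = gaussian_pdf(sensors, 1.0, 0.8)
branch = init_mlp([2 * m, 64, p], rng)
trunk = init_mlp([2, 64, p], rng)

# The same (untrained) network is queried on a coarse and a fine grid at t = 0.5.
for n in (20, 400):
    x = np.linspace(-3.0, 3.0, n)
    xt = np.column_stack([x, np.full(n, 0.5)])
    print(n, deeponet(branch, trunk, mu0, mu1, xt).shape)  # prints "20 (20,)" then "400 (400,)"
```

Because the trunk network takes continuous coordinates, the query grid is decoupled from the sensor grid, which is the property that allows evaluation at resolutions not seen during training. A primal–dual model would presumably need outputs for both the density and the dual potential; that pairing is an assumption and is not implemented in this sketch.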
