Wasserstein Wormhole: Scalable Optimal Transport Distance with Transformers
Doron Haviv, Russell Zhang Kunes, Thomas Dougherty, Cassandra Burdziak, Tal Nawy, Anna Gilbert, Dana Pe'er
TL;DR
The paper introduces Wasserstein Wormhole, a transformer-based autoencoder that embeds variable-size point clouds into a Euclidean latent space in which pairwise Euclidean distances approximate Wasserstein distances, making optimal transport (OT) computations scalable. It develops a theoretical framework for embedding non-Euclidean distance matrices, deriving lower and upper bounds on the embedding error, and proposes a projected gradient descent approach that is guaranteed to converge to the global optimum for the closest Euclidean distance matrix (EDM). Empirically, Wormhole embeddings closely track OT distances on MNIST, Fashion-MNIST, ModelNet40, ShapeNet, and high-dimensional cellular niches, and the method scales to large cohorts with minimal OT calls. The decoder lends interpretability by enabling barycenter estimation and interpolation directly from the embedding space. The method also extends to Gromov-Wasserstein (GW) distances, where it preserves rotational invariance, and it offers practical utility for downstream tasks such as clustering, visualization, and high-dimensional OT analyses in computational geometry and biology.
Abstract
Optimal transport (OT) and the related Wasserstein metric (W) are powerful and ubiquitous tools for comparing distributions. However, computing pairwise Wasserstein distances rapidly becomes intractable as cohort size grows. An attractive alternative would be to find an embedding space in which pairwise Euclidean distances map to OT distances, akin to standard multidimensional scaling (MDS). We present Wasserstein Wormhole, a transformer-based autoencoder that embeds empirical distributions into a latent space wherein Euclidean distances approximate OT distances. Extending MDS theory, we show that our objective function implies a bound on the error incurred when embedding non-Euclidean distances. Empirically, distances between Wormhole embeddings closely match Wasserstein distances, enabling linear-time computation of OT distances. Along with an encoder that maps distributions to embeddings, Wasserstein Wormhole includes a decoder that maps embeddings back to distributions, allowing operations in the embedding space to generalize to OT spaces, such as Wasserstein barycenter estimation and OT interpolation. By lending scalability and interpretability to OT approaches, Wasserstein Wormhole unlocks new avenues for data analysis in the fields of computational geometry and single-cell biology.
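The core idea, an embedding whose Euclidean distances reproduce OT distances and a decoder that turns embedding-space averages into barycenters, can be illustrated with a toy case. The sketch below is not the Wormhole model: it assumes one-dimensional point clouds of equal size with uniform weights, where the 2-Wasserstein distance is obtained exactly by matching sorted points. The functions `toy_encode` and `toy_decode` are hypothetical stand-ins for the learned transformer encoder and decoder, and `stress` is an MDS-style mismatch between embedding distances and OT distances, used here only for illustration.

```python
import math
import random


def w2_1d(xs, ys):
    """Exact 2-Wasserstein distance between two equal-size 1D samples.

    With uniform weights in 1D, the optimal coupling matches sorted points.
    """
    assert len(xs) == len(ys)
    n = len(xs)
    sq = sum((a - b) ** 2 for a, b in zip(sorted(xs), sorted(ys)))
    return math.sqrt(sq / n)


def toy_encode(xs):
    """Hypothetical encoder stand-in: sorted sample scaled by 1/sqrt(n)."""
    n = len(xs)
    return [v / math.sqrt(n) for v in sorted(xs)]


def toy_decode(z):
    """Hypothetical decoder stand-in: undo the 1/sqrt(n) scaling."""
    n = len(z)
    return [v * math.sqrt(n) for v in z]


def euclidean(u, v):
    """Plain Euclidean distance between two embedding vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def stress(embeddings, clouds):
    """Mean squared gap between embedding distances and W2 over all pairs."""
    gaps = []
    for i in range(len(clouds)):
        for j in range(i + 1, len(clouds)):
            d_embed = euclidean(embeddings[i], embeddings[j])
            d_ot = w2_1d(clouds[i], clouds[j])
            gaps.append((d_embed - d_ot) ** 2)
    return sum(gaps) / len(gaps)


random.seed(0)
# Three toy point clouds: 50 Gaussian samples each, with different means.
clouds = [[random.gauss(mu, 1.0) for _ in range(50)] for mu in (0.0, 1.0, 3.0)]
embeddings = [toy_encode(c) for c in clouds]

# In this 1D toy the embedding is an exact isometry, so the stress
# vanishes up to floating-point error.
print("stress:", stress(embeddings, clouds))

# Barycenter estimate: average in embedding space, then decode.
mean_z = [sum(col) / len(embeddings) for col in zip(*embeddings)]
barycenter = toy_decode(mean_z)
```

In one dimension the sorted-sample embedding is exact, so no training is needed. In higher dimensions no such closed form exists, which is the setting where a learned encoder that only approximates the OT geometry, as the abstract describes, becomes useful.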
