Wasserstein Wormhole: Scalable Optimal Transport Distance with Transformers

Doron Haviv, Russell Zhang Kunes, Thomas Dougherty, Cassandra Burdziak, Tal Nawy, Anna Gilbert, Dana Pe'er

TL;DR

The paper introduces Wasserstein Wormhole, a transformer-based autoencoder that embeds variable-size point clouds into a Euclidean latent space where pairwise Euclidean distances approximate Wasserstein distances, enabling scalable OT computations. It develops a theoretical framework for embedding non-Euclidean distance matrices, deriving lower and upper bounds on the embedding error, and proposes a projected gradient descent approach that is guaranteed to converge to the global optimum for the closest Euclidean distance matrix (EDM). Empirically, Wormhole produces OT-faithful embeddings on MNIST, Fashion-MNIST, ModelNet40, ShapeNet, and high-dimensional cellular niches; its decoder enables barycenter estimation and interpolation, and the method scales to large cohorts with a minimal number of OT calls. It supports Gromov-Wasserstein (GW) extensions, preserves rotational invariance in GW mode, and is useful for downstream tasks such as clustering, visualization, and high-dimensional OT analyses in computational geometry and biology. The work emphasizes scalability, interpretability through the decoder, and theoretical guarantees for embedding non-Euclidean OT distances into Euclidean space.
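The quantity Wormhole learns to approximate is the full pairwise Wasserstein matrix, which requires one OT solve for each of the $M(M-1)/2$ pairs in a cohort of $M$ point clouds. The sketch below is illustrative and not the authors' code: it makes that cost concrete for the simplest case, exact $W_2$ between uniform point clouds of equal size with a squared-Euclidean ground cost, computed by enumerating all matchings (valid because an optimal plan between two uniform measures of equal size can be chosen to be a permutation). The names `w2_uniform_bruteforce` and `pairwise_w2` are hypothetical, the equal-size assumption is ours, and the brute force is only feasible for a handful of points; practical solvers use assignment algorithms or entropic (Sinkhorn) approximations.

```python
import itertools

import numpy as np


def w2_uniform_bruteforce(x: np.ndarray, y: np.ndarray) -> float:
    """Exact W_2 between two uniform point clouds of equal size n (illustrative only).

    Assumes squared-Euclidean ground cost. By Birkhoff-von Neumann, an optimal
    plan can be taken to be a permutation, so all n! matchings are enumerated.
    """
    n = x.shape[0]
    assert y.shape[0] == n, "this sketch assumes equal-size point clouds"
    # Squared Euclidean cost between every point of x and every point of y, shape (n, n).
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    best = min(
        cost[np.arange(n), list(perm)].sum()
        for perm in itertools.permutations(range(n))
    )
    # Uniform weights 1/n, then the square root turns the OT cost into W_2.
    return float(np.sqrt(best / n))


def pairwise_w2(clouds: list) -> np.ndarray:
    """Symmetric (M, M) matrix of W_2 distances: one OT solve per unordered pair."""
    m = len(clouds)
    d = np.zeros((m, m))
    for i in range(m):
        for j in range(i + 1, m):
            d[i, j] = d[j, i] = w2_uniform_bruteforce(clouds[i], clouds[j])
    return d
```

The double loop in `pairwise_w2` is the quadratic number of OT problems that becomes intractable as the cohort grows.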

Abstract

Optimal transport (OT) and the related Wasserstein metric (W) are powerful and ubiquitous tools for comparing distributions. However, computing pairwise Wasserstein distances rapidly becomes intractable as cohort size grows. An attractive alternative would be to find an embedding space in which pairwise Euclidean distances map to OT distances, akin to standard multidimensional scaling (MDS). We present Wasserstein Wormhole, a transformer-based autoencoder that embeds empirical distributions into a latent space wherein Euclidean distances approximate OT distances. Extending MDS theory, we show that our objective function implies a bound on the error incurred when embedding non-Euclidean distances. Empirically, distances between Wormhole embeddings closely match Wasserstein distances, enabling linear time computation of OT distances. Along with an encoder that maps distributions to embeddings, Wasserstein Wormhole includes a decoder that maps embeddings back to distributions, allowing for operations in the embedding space to generalize to OT spaces, such as Wasserstein barycenter estimation and OT interpolation. By lending scalability and interpretability to OT approaches, Wasserstein Wormhole unlocks new avenues for data analysis in the fields of computational geometry and single-cell biology.
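Classical MDS, which the abstract takes as its starting point, recovers coordinates whose Euclidean distances reproduce a given distance matrix whenever that matrix is Euclidean. The textbook NumPy sketch below is included for orientation only and is not Wormhole's method; `classical_mds` and `pairwise_dist` are hypothetical helper names. It needs the full $N \times N$ distance matrix up front, which is precisely the all-pairs OT computation that Wormhole avoids by training on mini-batches, and it can only discard, not repair, the non-Euclidean part of the input, which is what the bounds of Theorems 5.1 and 5.2 quantify.

```python
import numpy as np


def classical_mds(dist: np.ndarray, k: int) -> np.ndarray:
    """Classical (Torgerson) MDS: N x N matrix of distances -> N x k coordinates."""
    n = dist.shape[0]
    j = np.eye(n) - np.ones((n, n)) / n        # centering matrix J
    gram = -0.5 * j @ (dist ** 2) @ j          # double-centered squared distances
    evals, evecs = np.linalg.eigh(gram)        # eigenvalues in ascending order
    top = np.argsort(evals)[::-1][:k]          # indices of the k largest eigenvalues
    # Negative eigenvalues signal non-Euclidean structure; clipping them to zero
    # is where classical MDS loses information on such inputs.
    lam = np.clip(evals[top], 0.0, None)
    return evecs[:, top] * np.sqrt(lam)


def pairwise_dist(x: np.ndarray) -> np.ndarray:
    """Euclidean distances between all rows of x."""
    return np.sqrt(((x[:, None, :] - x[None, :, :]) ** 2).sum(-1))
```

For points that already live in $\mathbb{R}^k$, `pairwise_dist(classical_mds(pairwise_dist(x), k))` reproduces the input distances up to floating-point error, since the embedding differs from `x` only by a rigid motion.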

Paper Structure

This paper contains 29 sections, 4 theorems, 56 equations, 9 figures, 3 tables, and 2 algorithms.

Key Result

Theorem 5.1

For a given distance matrix $D$ and the eigendecomposition $\{\lambda_{i}, v_{i}\}_{i=1}^{N}$ of its criterion matrix $F=-JDJ$, where $J = I - \frac{1}{N}\mathbf{1}\mathbf{1}^{\top}$ is the centering matrix, the optimal stress is greater than or equal to:
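Theorem 5.1 is stated in terms of the spectrum of the criterion matrix $F$. The sketch below computes that spectrum and applies Schoenberg's classical criterion, under which a symmetric, zero-diagonal matrix of squared distances is a Euclidean distance matrix exactly when $F$ is positive semidefinite. It assumes that $D$ holds squared distances, as in classical MDS; the function names are hypothetical, and the code does not evaluate the bound itself.

```python
import numpy as np


def criterion_matrix(sq_dist: np.ndarray) -> np.ndarray:
    """F = -J D J with J = I - (1/N) 1 1^T.

    Assumption: D is symmetric, has a zero diagonal, and holds squared distances.
    """
    n = sq_dist.shape[0]
    j = np.eye(n) - np.ones((n, n)) / n
    return -j @ sq_dist @ j


def criterion_spectrum(sq_dist: np.ndarray):
    """Eigenpairs {lambda_i, v_i} of F, with eigenvalues in ascending order."""
    return np.linalg.eigh(criterion_matrix(sq_dist))


def is_euclidean(sq_dist: np.ndarray, rtol: float = 1e-9) -> bool:
    """Schoenberg's criterion: D is an EDM iff F is positive semidefinite."""
    evals, _ = criterion_spectrum(sq_dist)
    # Tolerate tiny negative eigenvalues caused by floating-point round-off.
    scale = max(1.0, float(np.abs(evals).max()))
    return bool(evals.min() >= -rtol * scale)
```

For example, a hub at distance 1 from three leaves that are pairwise distance 2 apart satisfies the triangle inequality but cannot be placed in any Euclidean space; its criterion matrix has a negative eigenvalue, so `is_euclidean` returns `False` for its squared-distance matrix.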

Figures (9)

  • Figure 1: Schematic of Wasserstein Wormhole. Empirical distributions (point clouds) are passed through a transformer to produce a vector embedding per point cloud, such that the Euclidean distance between embeddings matches the pairwise Wasserstein distance between point clouds. Since computing OT distances is expensive, Wormhole is optimized on mini-batches to minimize the discrepancy between the pairwise embedding distances and the pairwise Wasserstein distances of the point clouds in the batch. The Wormhole decoder (not shown) is a second transformer trained to reproduce the input point clouds from the embedding by minimizing the OT distance between input and output.
  • Figure 2: Benchmarking run-time of Wormhole against other OT acceleration algorithms. Sampling cohorts of different sizes from the MNIST dataset, we measured the time required by current acceleration algorithms (with GPU implementations in JAX-OTT) and by Wormhole to compute or approximate the pairwise Wasserstein matrix. Other methods are more appropriate for tiny cohorts, as they do not require training a parametric model. However, even for cohorts with relatively few samples, Wormhole is faster, and no other method scales to complete datasets, which would require weeks of compute time even on a fully utilized 80GB GPU.
  • Figure 3: Wormhole on MNIST point clouds. a. Example of a point cloud produced by thresholding MNIST images. b. Retrieval of the Wasserstein distance in Wormhole embedding space. For a random sample of $128$ point clouds from the test set, we show the correlation between the true pairwise OT distances and the Euclidean distances between their embeddings. The line $y=x$ is drawn in blue for reference. c. 2D UMAP visualization of Wormhole embeddings of the training- and test-set point clouds, which recapitulates the ground-truth digit labels without using label information during training.
  • Figure 4: Generalization to out-of-distribution (OOD) samples on Fashion-MNIST. During training, we held out point clouds labeled 'Bag'. Performance was slightly lower than for observed classes, especially at larger Wasserstein distances: the MSE for OOD samples was $4.38\cdot10^{-5}$, compared with $3.37\cdot10^{-5}$ for observed samples. Although Wormhole was not trained on any 'Bag' point cloud, their encodings still largely agree with the true OT distances, yielding a Pearson correlation of $\rho=0.995$ and a label accuracy of $0.97$.
  • Figure 5: Wormhole Variational Wasserstein. a. Mean encodings of each class were passed through the decoder to visualize their Wasserstein barycenters. Working in Wormhole space provides a quick and straightforward proxy for OT analysis, and the recovered decodings reproduce the underlying classes well. For brevity, we show 10 of the 40 classes from ModelNet40. b. Wormhole interpolation between two point clouds within the same class (top) and across different classes (bottom). Intermediate point clouds are decodings of linear interpolations between embeddings.
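Figure 1 describes the training signal as the discrepancy between pairwise embedding distances and pairwise Wasserstein distances within a mini-batch, and Figure 5 obtains barycenters and interpolations by decoding class-mean and linearly interpolated embeddings. The NumPy sketch below illustrates these embedding-space operations under stated assumptions: the discrepancy is taken to be a mean squared error over off-diagonal pairs, the batch OT matrix is precomputed, the decoder is not modeled, and all function names are hypothetical. An actual trainer would express the loss in an autodiff framework such as JAX so that gradients reach the encoder.

```python
import numpy as np


def pairwise_euclidean(z: np.ndarray) -> np.ndarray:
    """Euclidean distances between all rows of a (B, d) embedding batch."""
    return np.sqrt(((z[:, None, :] - z[None, :, :]) ** 2).sum(-1))


def batch_discrepancy(z: np.ndarray, ot_dist: np.ndarray) -> float:
    """Assumed MSE between embedding distances and a (B, B) OT distance matrix.

    Only off-diagonal pairs are averaged; the diagonal is zero in both matrices.
    """
    b = z.shape[0]
    off_diag = ~np.eye(b, dtype=bool)
    err = pairwise_euclidean(z) - ot_dist
    return float(np.mean(err[off_diag] ** 2))


def interpolate(z_a: np.ndarray, z_b: np.ndarray, num: int = 5) -> np.ndarray:
    """Evenly spaced points on the segment from z_a to z_b (one per row)."""
    t = np.linspace(0.0, 1.0, num)[:, None]
    # Each row would be passed to the decoder to obtain an intermediate point cloud.
    return (1.0 - t) * z_a[None, :] + t * z_b[None, :]


def class_means(z: np.ndarray, labels: np.ndarray) -> dict:
    """Mean embedding per class; Figure 5a decodes such means to visualize barycenters."""
    return {c: z[labels == c].mean(axis=0) for c in np.unique(labels)}
```

Because only the $B \times B$ OT matrix of the current batch enters `batch_discrepancy`, the number of OT solves per step depends on the batch size rather than on the cohort size.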

Theorems & Definitions (7)

  • Definition 3.1: Euclidean distance matrix (EDM)
  • Theorem 5.1: Lower Bound
  • Theorem 5.2: Upper Bound
  • Lemma 9.1
  • Proof of Lemma 9.1
  • Lemma 9.2
  • Proof of Lemma 9.2