Fast Estimation of Wasserstein Distances via Regression on Sliced Wasserstein Distances

Khai Nguyen, Hai Nguyen, Nhat Ho

TL;DR

This work proposes a fast estimator that regresses the Wasserstein distance on sliced Wasserstein (SW) distances. It consistently approximates the Wasserstein distance better than the state-of-the-art Wasserstein embedding model, Wasserstein Wormhole, particularly in low-data regimes.

Abstract

We address the problem of efficiently computing Wasserstein distances for multiple pairs of distributions drawn from a meta-distribution. To this end, we propose a fast estimation method based on regressing Wasserstein distance on sliced Wasserstein (SW) distances. Specifically, we leverage both standard SW distances, which provide lower bounds, and lifted SW distances, which provide upper bounds, as predictors of the true Wasserstein distance. To ensure parsimony, we introduce two linear models: an unconstrained model with a closed-form least-squares solution, and a constrained model that uses only half as many parameters. We show that accurate models can be learned from a small number of distribution pairs. Once estimated, the model can predict the Wasserstein distance for any pair of distributions via a linear combination of SW distances, making it highly efficient. Empirically, we validate our approach on diverse tasks, including Gaussian mixtures, point-cloud classification, and Wasserstein-space visualizations for 3D point clouds. Across various datasets such as MNIST point clouds, ShapeNetV2, MERFISH Cell Niches, and scRNA-seq, our method consistently provides a better approximation of Wasserstein distance than the state-of-the-art Wasserstein embedding model, Wasserstein Wormhole, particularly in low-data regimes. Finally, we demonstrate that our estimator can also accelerate Wormhole training, yielding RG-Wormhole.
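The unconstrained linear model described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the meta-distribution (Gaussians with diagonal covariance, chosen because their Wasserstein-2 distance has a closed form), the two predictors (a Monte Carlo SW distance and the maximum over the same sampled directions, both lower bounds), and all function names are assumptions made for illustration. The lifted SW upper bounds used in the paper are not implemented here.

```python
import numpy as np

def w2_diag_gauss(m1, s1, m2, s2):
    """Exact W2 between N(m1, diag(s1^2)) and N(m2, diag(s2^2))."""
    return np.sqrt(np.sum((m1 - m2) ** 2) + np.sum((s1 - s2) ** 2))

def sw_features(m1, s1, m2, s2, theta):
    """Two illustrative predictors from unit directions theta of shape (L, d)."""
    mean_gap = theta @ (m1 - m2)                    # projected mean difference
    std1 = np.sqrt((theta ** 2) @ (s1 ** 2))        # projected standard deviations
    std2 = np.sqrt((theta ** 2) @ (s2 ** 2))
    w2_sq_1d = mean_gap ** 2 + (std1 - std2) ** 2   # closed-form 1D Gaussian W2^2
    # Monte Carlo SW2 (average over directions) and max over sampled directions.
    return np.array([np.sqrt(w2_sq_1d.mean()), np.sqrt(w2_sq_1d.max())])

def sample_pairs(n_pairs, d, theta, rng):
    """Draw pairs from a toy meta-distribution (assumption, not from the paper)."""
    S, W = [], []
    for _ in range(n_pairs):
        m1, m2 = rng.normal(size=(2, d))
        s1, s2 = rng.uniform(0.5, 2.0, size=(2, d))
        S.append(sw_features(m1, s1, m2, s2, theta))
        W.append(w2_diag_gauss(m1, s1, m2, s2))
    return np.array(S), np.array(W)

rng = np.random.default_rng(0)
d, n_dirs = 5, 256
theta = rng.normal(size=(n_dirs, d))
theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # unit-norm directions

# Fit on a small number of pairs, then predict on unseen pairs.
S_train, W_train = sample_pairs(20, d, theta, rng)
S_test, W_test = sample_pairs(200, d, theta, rng)

# Least-squares fit of W on the SW predictors (no intercept).
coef, *_ = np.linalg.lstsq(S_train, W_train, rcond=None)

# Prediction is a linear combination of SW distances.
W_pred = S_test @ coef
rel_err = np.mean(np.abs(W_pred - W_test) / W_test)
print("coefficients:", coef)
print("mean relative error on held-out pairs:", rel_err)
```

In this sketch the expensive quantity (the exact Wasserstein distance) is needed only for the 20 training pairs; every later pair costs one evaluation of the SW predictors and a dot product with `coef`.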

Paper Structure

This paper contains 22 sections, 24 equations, 30 figures, 6 tables.

Figures (30)

  • Figure 1: Linear regression of the Wasserstein distance vector $\hat{\boldsymbol{W}}$ on sliced Wasserstein (SW) distances $\hat{\boldsymbol{S}}^{(1)},\ldots,\hat{\boldsymbol{S}}^{(K)}$. The left figure illustrates a linear model, interpreted as the $\mathbb{L}_2$ projection of the Wasserstein distance onto the linear span of the SW distances. The right figure depicts a special case of a constrained linear model with only two SW distances as predictors, which can be seen as a midpoint method.
  • Figure 2: ModelNet40: an RG-Wormhole variant in the reconstruction experiment.
  • Figure 3: ModelNet40: an RG-Wormhole variant in the interpolation experiment.
  • Figure 4: Optimal $w^*$ and $R^2$ in each dimension.
  • Figure 5: Embeddings produced by each method on the ShapeNetV2 dataset.
  • ...and 25 more figures

Theorems & Definitions (7)

  • Remark 1
  • Remark 2
  • Definition 1: Regression of Wasserstein distance onto SW distances
  • Definition 2: Linear Regression of Wasserstein distance onto SW distances
  • Remark 3
  • Definition 3: Constrained Linear Regression of Wasserstein distance onto SW distances
  • Remark 4