Towards Better Spherical Sliced-Wasserstein Distance Learning with Data-Adaptive Discriminative Projection Direction

Hongliang Zhang, Shuo Chen, Lei Luo, Jian Yang

TL;DR

This work addresses the inadequacy of equally weighted projection directions in the spherical sliced-Wasserstein (SSW) distance by introducing DSSW, a data-adaptive, discriminative distance on the sphere. DSSW uses a projected energy function to assign weights to projection directions, in both a non-parametric form, g(h(·,·)), and a parametric form based on a neural network h_ψ, and it is backed by theoretical guarantees including convergence and Monte Carlo error bounds. Experiments on gradient flows on the sphere, density estimation on earth data with normalizing flows, a sliced-Wasserstein autoencoder, and self-supervised learning show that DSSW consistently outperforms the SW, SSW, and S3W baselines, with the parametric variant performing best at a higher computational cost. The approach offers a practical, scalable improvement for comparing distributions on hyperspheres and may also benefit other non-Euclidean sliced-Wasserstein methods.

Abstract

Spherical Sliced-Wasserstein (SSW) has recently been proposed to measure the discrepancy between spherical data distributions in various fields, such as geology, medical domains, computer vision, and deep representation learning. However, the original SSW treats all projection directions equally, which is too idealistic and cannot accurately reflect the importance of different projection directions for various data distributions. To address this issue, we propose a novel data-adaptive Discriminative Spherical Sliced-Wasserstein (DSSW) distance, which utilizes a projected energy function to determine the discriminative projection direction for SSW. In DSSW, we introduce two types of projected energy functions that generate the weights for projection directions, both with complete theoretical guarantees. The first type employs a non-parametric deterministic function that transforms the projected Wasserstein distance into its corresponding weight in each projection direction. This improves the performance of the original SSW distance with negligible additional computational overhead. The second type utilizes a neural network-induced function that learns the projection direction weight through a parameterized neural network based on data projections. This further enhances the performance of the original SSW distance at a somewhat higher computational cost. Finally, we evaluate the performance of our proposed DSSW by comparing it with several state-of-the-art methods across a variety of machine learning tasks, including gradient flows, density estimation on real earth data, and self-supervised learning.
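
The non-parametric weighting described above can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. It assumes equal-size empirical samples on the sphere, projects them onto random great circles, computes the circular Wasserstein cost by brute force over cyclic shifts of the sorted angles (relying on the known fact that an optimal matching on the circle preserves cyclic order for convex costs), and turns the projected distances into weights with a softmax, in the spirit of the "exp" variant named in the figure captions. The function names `random_frames`, `project_to_circle`, `circular_wasserstein`, and `dssw_sketch` are hypothetical.

```python
import numpy as np

def random_frames(d, L, rng):
    # L random d x 2 matrices with orthonormal columns (QR of Gaussian draws).
    return np.stack([np.linalg.qr(rng.standard_normal((d, 2)))[0] for _ in range(L)])

def project_to_circle(X, U):
    # Angle in [0, 2*pi) of each point after projection onto the plane spanned by U.
    Y = X @ U
    return np.mod(np.arctan2(Y[:, 1], Y[:, 0]), 2 * np.pi)

def circular_wasserstein(a, b, p=2):
    # W_p^p between two equal-size empirical measures on the circle,
    # found by trying every cyclic shift of the sorted angles (O(n^2)).
    a, b = np.sort(a), np.sort(b)
    best = np.inf
    for k in range(len(b)):
        diff = np.abs(a - np.roll(b, k))
        geo = np.minimum(diff, 2 * np.pi - diff)  # geodesic distance on the circle
        best = min(best, np.mean(geo ** p))
    return best

def dssw_sketch(X, Y, L=50, p=2, seed=0):
    # Returns (uniformly weighted SSW, softmax-weighted variant, weights).
    rng = np.random.default_rng(seed)
    frames = random_frames(X.shape[1], L, rng)
    w_pp = np.array([
        circular_wasserstein(project_to_circle(X, U), project_to_circle(Y, U), p)
        for U in frames
    ])
    energy = w_pp ** (1.0 / p)            # projected distance per direction
    e = np.exp(energy - energy.max())     # numerically stable softmax (assumed "exp" choice)
    weights = e / e.sum()
    ssw = np.mean(w_pp) ** (1.0 / p)
    dssw = np.sum(weights * w_pp) ** (1.0 / p)
    return ssw, dssw, weights

if __name__ == "__main__":
    rng = np.random.default_rng(1)
    X = rng.standard_normal((64, 3))
    X /= np.linalg.norm(X, axis=1, keepdims=True)
    Y = rng.standard_normal((64, 3)) + np.array([2.0, 0.0, 0.0])
    Y /= np.linalg.norm(Y, axis=1, keepdims=True)
    ssw, dssw, w = dssw_sketch(X, Y, L=20)
    print("uniform:", ssw, "weighted:", dssw)
```

Because the softmax weights grow with the projected distance, the weighted value in this sketch is never smaller than the uniformly weighted one, which is one way to read "discriminative" here: directions that separate the two samples more strongly contribute more.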

Paper Structure (35 sections, 8 theorems, 25 equations, 38 figures, 8 tables)

Key Result

Proposition 1

For any $p \ge 1$ and the projected energy function $f$, the DSSW distance $DSSW_p$ is positive and symmetric.
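
A short sketch of why this is plausible, assuming the distance takes the weighted sliced form suggested by the abstract. The exact statement of Definition 1 is not reproduced on this page, so the symbols $P^U$ (projection onto the great circle indexed by $U$), $\sigma$ (a probability measure over directions), and $V_{d,2}$ (the set of directions) are assumptions:

$$
DSSW_p^p(\mu, \nu) = \int_{V_{d,2}} f\big(P^U_{\#}\mu, P^U_{\#}\nu\big)\, W_p^p\big(P^U_{\#}\mu, P^U_{\#}\nu\big)\, \mathrm{d}\sigma(U)
$$

Under this form, if the weight satisfies $f \ge 0$ and $f(a, b) = f(b, a)$, the integrand is non-negative and unchanged when $\mu$ and $\nu$ are swapped, because $W_p$ is itself non-negative and symmetric. Both properties then carry over to $DSSW_p$.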

Figures (38)

  • Figure 1: Runtime comparison for Wasserstein distance, Sinkhorn distance with geodesic distance as cost function, $SW_2$ (SW distance), $SSW_1$ distance with the level median, $SSW_2$ distance with binary search (BS), $SSW_2$ distance against a uniform distribution (Unif), $S3W_2$ distance, $RI$-$S3W_2$ (rotationally invariant extension of $S3W_2$) distance, $ARI$-$S3W_2$ (amortized rotationally invariant extension of $S3W_2$) distance, $DSSW_1$ (exp) (ours), $DSSW_2$ (exp), BS (ours), $DSSW_2$ (exp), Unif (ours).
  • Figure 2: The Mollweide projections for mini-batch projected gradient descent. We use 1, 5, and 10 rotations for RI-S3W (1), RI-S3W (5), and RI-S3W (10), respectively. We also use 30 rotations with a pool size of 1000 for ARI-S3W (30).
  • Figure 3: Projected features on $\mathbb{S}^{2}$ for CIFAR10
  • Figure A4: The network architecture of $h_{\psi }$ used in DSSW. $\hat{X}$ and $\hat{Y}$ denote the projection of $m$ samples from the source distribution $\mu$ with size of $L \times m$ and the projection of $n$ samples from the target distribution $\nu$ with size of $L \times n$, respectively. Notation "$\oplus$" indicates the concatenation operation, notation "$\otimes$" denotes the matrix multiplication.
  • Figure A5: Runtime comparison for Wasserstein distance, Sinkhorn distance with geodesic distance as cost function, $SW_2$ (SW distance), $SSW_1$ distance with the level median, $SSW_2$ distance with binary search (BS), $SSW_2$ distance against a uniform distribution (Unif), $S3W_2$ distance, $RI$-$S3W_2$ distance, $ARI$-$S3W_2$ distance, $DSSW_1$ (identity/poly/linear/nonlinear/attention) (ours), $DSSW_2$ (identity/poly/linear/nonlinear/attention), BS (ours), $DSSW_2$ (identity/poly/linear/nonlinear/attention), Unif (ours).
  • ...and 33 more figures

Theorems & Definitions (10)

  • Definition 1: DSSW Distance
  • Definition 2: Projected Energy Function
  • Proposition 1
  • Proposition 2
  • Proposition 3
  • Theorem 1
  • Proposition 4
  • Proposition 5
  • Proposition 6
  • Theorem 2