Fourier Sliced-Wasserstein Embedding for Multisets and Measures
Tal Amir, Nadav Dym
TL;DR
The Fourier Sliced-Wasserstein (FSW) embedding maps finite multisets and distributions in $\mathbb{R}^d$ into Euclidean space. It is provably injective on distributions and bi-Lipschitz on multisets, which makes it a geometrically faithful building block for learning on such non-Euclidean data. FSW projects the input onto random 1D directions and samples the Fourier transform of the quantile function of each projection; in expectation, this exactly captures the sliced-Wasserstein geometry. The finite-dimensional guarantees are strong: an output dimension of $m \ge 2Nd+1$ suffices for multisets and $m \ge 2Nd+2N-1$ for distributions, where $N$ is the maximal input size. The embedding costs $\mathcal{O}(mNd + mN\log N)$ to compute. In experiments, it outperforms prior methods such as PSWE at preserving distances and supporting learning tasks, and replacing max-pooling with FSW makes PointNet more robust to parameter reduction. Because a bi-Lipschitz Euclidean embedding of all finite distributions is provably impossible, these metric guarantees are, in that sense, the best achievable. The authors point to graph neural networks and broader transport-based metrics as natural directions for applying FSW.
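The project-sort-transform pipeline described above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the function name `sliced_quantile_cosine_features` is hypothetical, the multiset is assumed to carry uniform weights, the cosine transform of the quantile function is taken over $[0,1]$ without any normalization factor, and directions and frequencies are drawn from arbitrary placeholder distributions rather than the ones the paper prescribes. It is meant only to show where the $\mathcal{O}(mNd)$ projection cost and the $\mathcal{O}(mN\log N)$ sorting cost come from.

```python
import numpy as np

def sliced_quantile_cosine_features(X, directions, freqs):
    """Hypothetical sketch: one cosine-transform sample per (direction, frequency) pair.

    X          : (N, d) points of a multiset, assumed uniformly weighted.
    directions : (m, d) unit vectors.
    freqs      : (m,) strictly positive frequencies.
    Returns an (m,) feature vector.
    """
    X = np.asarray(X, dtype=float)
    directions = np.asarray(directions, dtype=float)
    freqs = np.asarray(freqs, dtype=float)
    n = X.shape[0]

    # Project onto each direction, O(mNd), then sort each column, O(mN log N).
    # The k-th sorted value is the quantile function on the interval ((k-1)/N, k/N].
    proj = np.sort(X @ directions.T, axis=0)            # shape (N, m)

    # Closed form of the integral of cos(2*pi*xi*t) over each quantile interval:
    # differences of sin(2*pi*xi*t) / (2*pi*xi) at the interval edges.
    edges = np.arange(n + 1) / n                        # shape (N + 1,)
    antideriv = np.sin(2.0 * np.pi * np.outer(edges, freqs)) / (2.0 * np.pi * freqs)
    weights = np.diff(antideriv, axis=0)                # shape (N, m)

    # Integral of Q(t) * cos(2*pi*xi*t) over [0, 1] for every (direction, frequency).
    return (proj * weights).sum(axis=0)

# Usage with placeholder sampling choices (assumptions, not the paper's).
rng = np.random.default_rng(0)
N, d = 5, 3
m = 2 * N * d + 1                                       # multiset bound quoted above
X = rng.normal(size=(N, d))
directions = rng.normal(size=(m, d))
directions /= np.linalg.norm(directions, axis=1, keepdims=True)
freqs = rng.uniform(0.1, 4.0, size=m)

emb = sliced_quantile_cosine_features(X, directions, freqs)
emb_shuffled = sliced_quantile_cosine_features(X[rng.permutation(N)], directions, freqs)
```

Because each projection is sorted before the transform is applied, `emb` and `emb_shuffled` coincide: the output depends on the multiset, not on the order in which its points are listed.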
Abstract
We present the Fourier Sliced-Wasserstein (FSW) embedding, a novel method to embed multisets and measures over $\mathbb{R}^d$ into Euclidean space. Our proposed embedding approximately preserves the sliced Wasserstein distance on distributions, thereby yielding geometrically meaningful representations that better capture the structure of the input. Moreover, it is injective on measures and bi-Lipschitz on multisets, a significant advantage over prevalent methods based on sum- or max-pooling, which are provably not bi-Lipschitz, and, in many cases, not even injective. The required output dimension for these guarantees is near-optimal: roughly $2Nd$, where $N$ is the maximal input multiset size. Furthermore, we prove that it is impossible to embed distributions over $\mathbb{R}^d$ into Euclidean space in a bi-Lipschitz manner. Thus, the metric properties of our embedding are, in a sense, the best possible. Through numerical experiments, we demonstrate that our method yields superior multiset representations that improve performance in practical learning tasks. Specifically, we show that (a) a simple combination of the FSW embedding with an MLP achieves state-of-the-art performance in learning the (non-sliced) Wasserstein distance; and (b) replacing max-pooling with the FSW embedding makes PointNet significantly more robust to parameter reduction, with only minor performance degradation even after a 40-fold reduction.
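As a small illustration of the non-injectivity remark about pooling, the toy example below applies sum- and max-pooling directly to raw 1D points, with no learned feature map in front (an assumption made here for simplicity; the multisets `A`, `B`, and `C` are made up for this sketch). Distinct multisets collide under each pooling operation, while their sorted values, which quantile-based embeddings build on, remain different.

```python
import numpy as np

# Three distinct two-point multisets on the real line (illustrative values).
A = np.array([[0.0], [2.0]])
B = np.array([[1.0], [1.0]])
C = np.array([[1.0], [2.0]])

# Sum-pooling cannot tell A from B: both sum to 2.
sum_A, sum_B = A.sum(axis=0), B.sum(axis=0)

# Max-pooling cannot tell A from C: both have maximum 2.
max_A, max_C = A.max(axis=0), C.max(axis=0)

# The sorted values (the quantile function of each multiset) do differ.
sorted_differs = not np.array_equal(np.sort(A, axis=0), np.sort(B, axis=0))
```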
