Point Cloud Classification via Deep Set Linearized Optimal Transport
Scott Mahan, Caroline Moosmüller, Alexander Cloninger
TL;DR
This work introduces Deep Set Linearized Optimal Transport (DSLOT), which embeds point clouds or probability measures into an $L^2$ space via the LOT map relative to a fixed reference distribution. By learning Brenier potentials with Input-Convex Neural Networks and taking gradients, the approach yields a near-isometric, permutation-invariant representation that supports efficient classification, even when classes are formed by shifts or scalings. The authors provide theoretical guarantees linking ICNN-based OT map approximations to Wasserstein-2 distances with high probability and show empirical improvements over DeepSets on a flow cytometry AML dataset, particularly when employing resampling from the reference measure. The method offers a scalable, geometry-aware framework for point-cloud classification that leverages generative OT maps and a learned, invariant classifier, with potential applicability to high-dimensional distributional data in biology and beyond.
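The claim that the LOT embedding handles shifts and scalings can be checked on a toy example. Below is a minimal pure-Python sketch. It assumes that the reference and each point cloud are uniform empirical measures with the same number of points. In that case an optimal plan can be chosen to be a permutation, so brute-force search over permutations gives an exact discrete OT map for very small clouds. This exact solver replaces the ICNN-based maps the paper actually learns, and the helper names (`ot_map`, `l2_sigma`, `w2`) and the toy data are hypothetical. The sketch computes the $L^2(\sigma)$ distance between embeddings and compares it with the exact $W_2$ distance for a shifted copy and a scaled copy of a cloud.

```python
import itertools
import math

def sq_dist(p, q):
    """Squared Euclidean distance between two points given as tuples."""
    return sum((a - b) ** 2 for a, b in zip(p, q))

def ot_map(source, target):
    """Exact discrete OT map between uniform clouds of equal size (hypothetical helper).

    With uniform weights and equal sizes an optimal plan can be taken to be a
    permutation, so this brute-forces all n! matchings (tiny n only).
    Entry i of the result is the image of source[i].
    """
    n = len(source)
    assert len(target) == n, "this sketch assumes equal-size clouds"
    best = min(
        itertools.permutations(range(n)),
        key=lambda perm: sum(sq_dist(source[i], target[perm[i]]) for i in range(n)),
    )
    return [target[j] for j in best]

def l2_sigma(map_a, map_b):
    """L^2(sigma) distance between two maps evaluated at the uniform reference points."""
    return math.sqrt(sum(sq_dist(p, q) for p, q in zip(map_a, map_b)) / len(map_a))

def w2(cloud_a, cloud_b):
    """Exact Wasserstein-2 distance between two uniform clouds of equal size."""
    return l2_sigma(cloud_a, ot_map(cloud_a, cloud_b))

# Hypothetical toy data: a 5-point reference and a 5-point cloud in 2D.
reference = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (0.5, 0.5)]
cloud = [(2.1, 1.3), (3.4, 0.6), (2.7, 2.2), (1.2, 1.9), (3.3, 2.6)]

# LOT embedding of a cloud: its OT map from the reference, sampled at the reference points.
emb = ot_map(reference, cloud)

# Shift by b: the embedding moves by b, so the LOT distance is |b|, which equals W2.
b = (0.7, -0.4)
shifted = [(x + b[0], y + b[1]) for x, y in cloud]
emb_shifted = ot_map(reference, shifted)
print(l2_sigma(emb, emb_shifted), w2(cloud, shifted), math.hypot(*b))

# Scaling by s > 0 about the origin: the embedding is scaled by s as well.
s = 1.5
scaled = [(s * x, s * y) for x, y in cloud]
emb_scaled = ot_map(reference, scaled)
print(l2_sigma(emb, emb_scaled), w2(cloud, scaled))
```

Each printed pair should agree up to floating-point rounding. For a general pair of clouds, the LOT distance is only an upper bound on $W_2$. Pairing the two embeddings point by point defines one particular coupling, and that coupling need not be optimal. For larger clouds, the brute-force search would be replaced by an assignment solver or by learned maps.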
Abstract
We introduce Deep Set Linearized Optimal Transport, an algorithm designed to efficiently embed point clouds into an $L^2$-space while simultaneously constructing a classifier to distinguish between various classes of point clouds. The embedding preserves specific low-dimensional structures within the Wasserstein space. Our approach is motivated by the observation that, under certain assumptions, $L^2$-distances between optimal transport maps from a shared fixed reference distribution to distinct point clouds approximate the Wasserstein-2 distance between these point clouds. To learn approximations of these transport maps, we employ input convex neural networks (ICNNs) and establish that, under specific conditions, Euclidean distances between samples from these ICNNs closely mirror Wasserstein-2 distances between the true distributions. Additionally, we train a discriminator network that attaches weights to these samples, yielding a permutation-invariant classifier that differentiates between classes of point clouds. We showcase the advantages of our algorithm over the standard deep set approach through experiments on a flow cytometry dataset with a limited number of labeled point clouds.
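The abstract does not specify how the discriminator weights the samples. The following is a minimal sketch, assuming one simple design. A linear score gives each sample a softmax weight, the samples are averaged using those weights, and a logistic head turns the pooled vector into a class probability. The function `weighted_pool_classifier`, its parameters, and the data are hypothetical, the parameters are fixed rather than trained, and this is not the authors' network. The sketch shows where permutation invariance comes from: every step is a sum over samples, so shuffling the samples leaves the output unchanged.

```python
import math
import random

def weighted_pool_classifier(samples, score_vec, head_vec, bias):
    """Hypothetical permutation-invariant classifier over a set of samples.

    Each sample gets a positive weight (softmax of a linear score), the samples
    are averaged with those weights, and a logistic head maps the pooled vector
    to a probability. Sums do not depend on sample order, hence the invariance.
    """
    # One scalar score per sample; subtract the max for a numerically stable softmax.
    scores = [sum(a * x for a, x in zip(score_vec, p)) for p in samples]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # Weighted average of the samples, coordinate by coordinate.
    dim = len(samples[0])
    pooled = [sum(w * p[k] for w, p in zip(weights, samples)) for k in range(dim)]

    # Logistic head on the pooled vector.
    logit = sum(h * v for h, v in zip(head_vec, pooled)) + bias
    return 1.0 / (1.0 + math.exp(-logit))

# Toy data: 50 hypothetical 2D samples (e.g. points produced by a learned OT map).
rng = random.Random(0)
samples = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(50)]
score_vec, head_vec, bias = (0.8, -0.3), (1.5, 0.4), -0.2  # fixed, untrained parameters

p_original = weighted_pool_classifier(samples, score_vec, head_vec, bias)
shuffled = samples[:]
rng.shuffle(shuffled)
p_shuffled = weighted_pool_classifier(shuffled, score_vec, head_vec, bias)
print(p_original, p_shuffled)  # equal up to floating-point rounding
```

Softmax keeps the weights positive and makes them sum to one, so the pooled vector is a convex combination of the samples. In a trained model, the score and head parameters would be fitted to labeled point clouds.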
