Fast and scalable Wasserstein-1 neural optimal transport solver for single-cell perturbation prediction
Yanshuo Chen, Zhengmian Hu, Wei Chen, Heng Huang
TL;DR
This paper tackles the problem of predicting perturbation responses from unpaired single-cell data by formulating it as a Wasserstein-1 ($W_1$) OT task. It introduces a two-step solver. The first step learns a transport direction by maximizing the Kantorovich dual over a 1-Lipschitz potential $f$. The second step learns a sample-specific transport step size through adversarial training, which realizes the map $T(x)=x-\eta(x)\nabla f(x)$. Empirically, the method learns monotonic transport maps on 2D toy data and matches or surpasses $W_2$ OT solvers on real scRNA-seq and imaging datasets. It also runs roughly $25$–$45\times$ faster and scales better to high-dimensional inputs. The approach enables fast, scalable distribution alignment for single-cell perturbation prediction and can be extended to conditional transport and other high-dimensional biological applications, offering a practical alternative to existing $W_2$-based OT solvers.
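The two-step structure described above can be illustrated on a toy problem whose optimal potential is known in closed form. The NumPy sketch below is not the authors' implementation and rests on several assumptions:

- The source and target are 2D Gaussians that differ by a pure translation $m$. In that case the linear potential $f(x) = -\langle x, u\rangle$ with $u = m/\|m\|$ is 1-Lipschitz and attains the $W_1$ dual.
- $f$ is written down by hand instead of being learned by a neural network.
- The adversarial training of $\eta(x)$ is replaced by a grid search over a single constant $\eta$. The search picks the value that minimizes the energy distance between the transported samples $T(x)$ and the target samples.

The helper names (`f`, `grad_f`, `transport`, `energy_distance`) are hypothetical.

```python
import numpy as np

# Hypothetical toy problem: source mu = N(0, I) ("control" cells) and
# target nu = N(m, I) ("perturbed" cells) in 2D, i.e. a pure translation.
rng = np.random.default_rng(0)
m = np.array([3.0, 4.0])            # shift between the clouds, ||m|| = 5
X = rng.normal(size=(500, 2))       # samples from mu
Y = rng.normal(size=(500, 2)) + m   # samples from nu

# For a translation, the linear potential f(x) = -<x, u>, u = m / ||m||,
# is 1-Lipschitz and attains the W1 dual sup E_mu[f] - E_nu[f] = ||m||.
u = m / np.linalg.norm(m)


def f(x):
    """Hand-written potential (a neural network in the actual method)."""
    return -x @ u


def grad_f(x):
    """Gradient of f at each row of x; every row has unit norm."""
    return np.tile(-u, (x.shape[0], 1))


# Monte Carlo estimate of the W1 dual objective (about ||m|| = 5 up to noise).
dual_value = f(X).mean() - f(Y).mean()


def transport(x, eta):
    """Map of the form T(x) = x - eta * grad f(x), with a scalar eta."""
    return x - eta * grad_f(x)


def mean_pairwise_dist(p, q):
    """Average Euclidean distance over all pairs (p_i, q_j)."""
    return np.linalg.norm(p[:, None, :] - q[None, :, :], axis=-1).mean()


def energy_distance(a, b):
    """V-statistic estimate of the energy distance between two samples."""
    return (2 * mean_pairwise_dist(a, b)
            - mean_pairwise_dist(a, a) - mean_pairwise_dist(b, b))


# Stand-in for the adversarial step: choose one constant step size by grid
# search, keeping the value whose pushed samples best match the target.
etas = np.linspace(0.0, 10.0, 101)
scores = [energy_distance(transport(X, e), Y) for e in etas]
eta_best = etas[int(np.argmin(scores))]

print(f"dual estimate of W1: {dual_value:.2f}, selected step size: {eta_best:.1f}")
```

On this toy problem both printed values should land near 5, meaning the pushed samples recover the translation. A single constant step works here only because every point travels the same distance. In general the required distance varies from cell to cell, so the step has to be a function $\eta(x)$, which is what the adversarial step in the method learns.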
Abstract
\textbf{Motivation:} Predicting single-cell perturbation responses requires mapping between two unpaired single-cell data distributions. Optimal transport (OT) theory provides a principled framework for constructing such mappings by minimizing transport cost. Recently, Wasserstein-2 ($W_2$) neural optimal transport solvers (\textit{e.g.}, CellOT) have been employed for this prediction task. However, $W_2$ OT relies on the general Kantorovich dual formulation, which involves optimizing over two conjugate functions. This leads to a complex min-max optimization problem that converges slowly. \\ \textbf{Results:} To address this limitation, we propose a novel solver based on the Wasserstein-1 ($W_1$) dual formulation. Unlike the $W_2$ dual, the $W_1$ dual reduces the optimization to a maximization problem over a single 1-Lipschitz function, eliminating the need for time-consuming min-max optimization. Solving the $W_1$ dual reveals only the transport direction and does not directly provide a unique optimal transport map. We therefore add a second step that uses adversarial training to determine an appropriate transport step size, thereby recovering the transport map. Our experiments demonstrate that the proposed $W_1$ neural optimal transport solver can mimic $W_2$ OT solvers in finding a unique, ``monotonic'' map on 2D datasets. Moreover, the $W_1$ OT solver performs on par with or better than $W_2$ OT solvers on real single-cell perturbation datasets. Furthermore, we show that the $W_1$ OT solver achieves a $25$--$45\times$ speedup and scales better to high-dimensional transport tasks. It can also be applied directly to single-cell RNA-seq datasets with highly variable genes. \\ \textbf{Availability and Implementation:} Our implementation and experiments are open-sourced at https://github.com/poseidonchan/w1ot.
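For reference, the two dual problems contrasted in the abstract can be written in their standard forms. The notation here is ours rather than the paper's: $\mu$ is the source (control) distribution, $\nu$ is the target (perturbed) distribution, and $\|f\|_L \le 1$ denotes the 1-Lipschitz constraint.

$$W_1(\mu,\nu) = \sup_{\|f\|_L \le 1} \; \mathbb{E}_{x\sim\mu}[f(x)] - \mathbb{E}_{y\sim\nu}[f(y)]$$

$$\tfrac{1}{2}W_2^2(\mu,\nu) = \mathcal{C}_{\mu,\nu} - \inf_{\varphi\ \text{convex}} \Big( \mathbb{E}_{x\sim\mu}[\varphi(x)] + \mathbb{E}_{y\sim\nu}[\varphi^*(y)] \Big), \qquad \mathcal{C}_{\mu,\nu} = \tfrac{1}{2}\mathbb{E}_{x\sim\mu}\|x\|^2 + \tfrac{1}{2}\mathbb{E}_{y\sim\nu}\|y\|^2$$

The convex conjugate $\varphi^*(y) = \sup_x \langle x, y\rangle - \varphi(x)$ is itself an optimization problem. One common way to handle it is to parameterize the inner maximizer as $\nabla\psi$ for a second convex potential $\psi$. This turns the $W_2$ objective into a min-max problem:

$$\tfrac{1}{2}W_2^2(\mu,\nu) = \mathcal{C}_{\mu,\nu} - \inf_{\varphi}\sup_{\psi} \Big( \mathbb{E}_{x\sim\mu}[\varphi(x)] + \mathbb{E}_{y\sim\nu}\big[\langle \nabla\psi(y), y\rangle - \varphi(\nabla\psi(y))\big] \Big)$$

The $W_1$ dual has no such inner problem. An optimal $f$ decreases at unit rate along the segments on which mass moves, so $-\nabla f(x)$ gives only the direction of motion. The distance $\eta(x)$ in $T(x)=x-\eta(x)\nabla f(x)$ must be determined separately.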
