Fast and scalable Wasserstein-1 neural optimal transport solver for single-cell perturbation prediction

Yanshuo Chen, Zhengmian Hu, Wei Chen, Heng Huang

TL;DR

This paper tackles the problem of predicting perturbation responses in unpaired single-cell data by formulating it as a Wasserstein-1 OT task. It introduces a two-step solver that first learns a transport direction via maximizing the Kantorovich dual over a 1‑Lipschitz potential $f$, then learns a sample-specific transport step size through adversarial training to realize the map $T(x)=x-\eta(x)\nabla f(x)$. Empirically, the method learns monotonic transport on 2D toy data, matches or surpasses $W_2$ OT performance on real scRNA-seq and imaging datasets, and achieves a substantial speedup (roughly $25$–$45\times$) with improved scalability to high-dimensional inputs. The approach enables fast, scalable distribution alignment for single-cell perturbation prediction and can be extended to conditional transport and other high-dimensional biological applications, offering a practical alternative to existing $W_2$-based OT solvers.
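As a rough illustration of the map $T(x)=x-\eta(x)\nabla f(x)$, the sketch below uses a case where both ingredients are known in closed form: the target is the source translated by a fixed vector. The linear potential, the constant step size, and all function names here are assumptions made for illustration. In the paper both $f$ and $\eta$ are learned neural networks, and the data are unpaired; the pairing below is used only to check the result.

```python
import math
import random

def potential(x, u):
    """Hypothetical 1-Lipschitz potential f(x) = -<u, x> for a unit vector u."""
    return -sum(ui * xi for ui, xi in zip(u, x))

def grad_potential(x, u):
    """Gradient of f(x) = -<u, x>: constant, with unit norm."""
    return [-ui for ui in u]

def transport(x, u, eta):
    """Apply T(x) = x - eta(x) * grad f(x)."""
    step = eta(x)
    g = grad_potential(x, u)
    return [xi - step * gi for xi, gi in zip(x, g)]

def dual_value(source, target, u):
    """Monte Carlo estimate of E_mu[f] - E_nu[f] for the potential above."""
    mean_src = sum(potential(x, u) for x in source) / len(source)
    mean_tgt = sum(potential(y, u) for y in target) / len(target)
    return mean_src - mean_tgt

random.seed(0)
shift = [3.0, 4.0]                      # toy "perturbation": a pure translation
shift_norm = math.hypot(*shift)         # Euclidean length of the shift
u = [s / shift_norm for s in shift]     # unit transport direction

source = [[random.gauss(0, 1), random.gauss(0, 1)] for _ in range(1000)]
target = [[x[0] + shift[0], x[1] + shift[1]] for x in source]

# Step 1 (direction): f is fixed analytically here instead of being trained.
w1_estimate = dual_value(source, target, u)

# Step 2 (step size): a constant eta stands in for the adversarially trained one.
eta = lambda x: shift_norm
mapped = [transport(x, u, eta) for x in source]

max_error = max(
    abs(m - t) for mp, tg in zip(mapped, target) for m, t in zip(mp, tg)
)
print(w1_estimate, max_error)
```

In this toy case the dual value equals the length of the shift and the map reproduces the translation. With real data neither quantity has a closed form, which is why the method trains both components.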

Abstract

\textbf{Motivation:} Predicting single-cell perturbation responses requires mapping between two unpaired single-cell data distributions. Optimal transport (OT) theory provides a principled framework for constructing such mappings by minimizing transport cost. Recently, Wasserstein-2 ($W_2$) neural optimal transport solvers (\textit{e.g.}, CellOT) have been employed for this prediction task. However, $W_2$ OT relies on the general Kantorovich dual formulation, which involves optimizing over two conjugate functions, leading to a complex min-max optimization problem that converges slowly. \\ \textbf{Results:} To address these challenges, we propose a novel solver based on the Wasserstein-1 ($W_1$) dual formulation. Unlike $W_2$, the $W_1$ dual simplifies the optimization to a maximization problem over a single 1-Lipschitz function, thus eliminating the need for time-consuming min-max optimization. Because solving the $W_1$ dual reveals only the transport direction and does not directly provide a unique optimal transport map, we add a second step that uses adversarial training to determine an appropriate transport step size, effectively recovering the transport map. Our experiments demonstrate that the proposed $W_1$ neural optimal transport solver can mimic $W_2$ OT solvers in finding a unique and "monotonic" map on 2D datasets. Moreover, the $W_1$ OT solver performs on par with or better than $W_2$ OT solvers on real single-cell perturbation datasets. Furthermore, we show that the $W_1$ OT solver achieves a $25 \sim 45\times$ speedup, scales better on high-dimensional transport tasks, and can be applied directly to single-cell RNA-seq datasets using highly variable genes as input. \\ \textbf{Availability and Implementation:} Our implementation and experiments are open-sourced at https://github.com/poseidonchan/w1ot.
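For reference, the contrast between the two dual problems can be written out in their standard textbook forms. The notation below ($\varphi$, $\psi$, $c$, $\|f\|_{L}$) is assumed here and may differ from the paper's. The general Kantorovich dual for a cost $c$ optimizes over a pair of potentials:

$$
\mathrm{OT}_c(\mu,\nu) \;=\; \sup_{\varphi(x)+\psi(y)\,\le\, c(x,y)} \; \mathbb{E}_{x\sim\mu}[\varphi(x)] + \mathbb{E}_{y\sim\nu}[\psi(y)]
$$

For the cost $c(x,y)=\|x-y\|$, the pair can be taken as $(f,-f)$ with $f$ 1-Lipschitz, which gives the Kantorovich-Rubinstein form with a single function:

$$
W_1(\mu,\nu) \;=\; \sup_{\|f\|_{L}\,\le\, 1} \; \mathbb{E}_{x\sim\mu}[f(x)] - \mathbb{E}_{y\sim\nu}[f(y)]
$$

With this sign convention, mass moves from $\mu$ toward $\nu$ along $-\nabla f$, which matches a map of the form $T(x)=x-\eta(x)\nabla f(x)$ with $\eta(x)\ge 0$.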

Paper Structure

This paper contains 31 sections, 15 equations, 4 figures.

Figures (4)

  • Figure 1: Overview of $W_1$OT. a. The single-cell perturbation prediction task requires a map from the control group ($\mu$) cells to the perturbation group ($\nu$) cells. b. Our proposed $W_1$ OT solver. The first step learns the transport direction by maximizing the Kantorovich dual problem. The second step learns the appropriate transport step size via adversarial training to construct the final transport map. Figure conceptually inspired by the illustration in the CellOT paper.
  • Figure 2: $W_1$OT solver on toy datasets. a. The "bookshelf" dataset. To verify that the transport map is "monotonic", we set 5 markers: triangle, square, diamond, circle, and cross. The markers keep their original order after transport. b. The "circles" dataset. It consists of 4 concentric circles. The inner and outer source circles are expected to be transported to the inner and outer target circles, respectively. There is no transport between the inner source circle and the outer target circle. c. The "swiss roll" dataset. The target distribution is a 2D swiss roll and the source distribution is a Gaussian. d. The "moons" dataset. The source and target distributions are half circles.
  • Figure 3: Performance benchmark on 2 real single-cell datasets. a. Performance on the 4i datasets without dimensionality reduction. Each perturbation is run independently 5 times. Performance is summarized over all perturbations in one dataset. b. Performance on 9 selected drugs from the sciplex3 dataset in the i.i.d. setting with autoencoder embeddings and 50 latent dimensions. "observed" denotes the best performance a model can achieve. "identity" denotes the lower-bound performance. c. The methods' performance in the o.o.d. setting, where the test set uses an unseen cell type. The arrowhead indicates the direction of better performance. All statistical tests are Wilcoxon rank-sum tests.
  • Figure 4: $W_1$OT solver on high-dimensional datasets. a. Scalability of the two neural OT solvers, shown as the time taken for 10,000 iterations at different input data dimensionalities. b. Performance on the sciplex3 dataset in the i.i.d. setting with 1000 highly variable genes as input for $W_1$OT and $W_2$OT; other methods use 50-dimensional latent embeddings. The arrowhead indicates the direction of better performance. Statistical significance is assessed with the Wilcoxon rank-sum test.