Wasserstein Parallel Transport for Predicting the Dynamics of Statistical Systems

Tristan Luca Saidi, Gonzalo Mena, Larry Wasserman, Florian Gunsilius

Abstract

Many scientific systems, such as cellular populations or economic cohorts, are naturally described by probability distributions that evolve over time. Predicting how such a system would have evolved under different forces or initial conditions is fundamental to causal inference, domain adaptation, and counterfactual prediction. However, the space of distributions often lacks the vector space structure on which classical methods rely. To address this, we introduce a general notion of parallel dynamics at a distributional level. We base this principle on parallel transport of tangent dynamics along optimal transport geodesics and call it ``Wasserstein Parallel Trends''. By replacing the vector subtraction of classic methods with geodesic parallel transport, we can provide counterfactual comparisons of distributional dynamics in applications such as causal inference, domain adaptation, and batch-effect correction in experimental settings. The main mathematical contribution is a novel notion of fanning scheme on the Wasserstein manifold that allows us to efficiently approximate parallel transport along geodesics while also providing the first theoretical guarantees for parallel transport in the Wasserstein space. We also show that Wasserstein Parallel Trends recovers the classic parallel trends assumption for averages as a special case and derive closed-form parallel transport for Gaussian measures. We deploy the method on synthetic data and two single-cell RNA sequencing datasets to impute gene-expression dynamics across biological systems.
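The classic parallel trends assumption for averages, which the abstract says is recovered as a special case, can be illustrated with a minimal sketch. It shows only the Euclidean baseline, where the control group's mean change is added to the treated group's pre-period mean by plain vector arithmetic; it is not the Wasserstein method of the paper. The function name `parallel_trends_counterfactual` and the toy arrays are hypothetical and chosen for illustration.

```python
import numpy as np


def parallel_trends_counterfactual(treated_pre, control_pre, control_post):
    """Impute the treated group's untreated post-period mean.

    Classic (Euclidean) parallel trends: the treated group's mean is assumed
    to have changed by the same vector as the control group's mean.
    Each argument is an array of shape (n_samples, n_features).
    """
    treated_pre = np.asarray(treated_pre, dtype=float)
    control_pre = np.asarray(control_pre, dtype=float)
    control_post = np.asarray(control_post, dtype=float)

    # Trend of the control group, expressed as a difference of means.
    control_trend = control_post.mean(axis=0) - control_pre.mean(axis=0)

    # Shift the treated pre-period mean by that same trend.
    return treated_pre.mean(axis=0) + control_trend


# Toy data (hypothetical): two samples per group, two features.
treated_pre = [[1.0, 2.0], [3.0, 4.0]]   # mean [2, 3]
control_pre = [[0.0, 0.0], [2.0, 2.0]]   # mean [1, 1]
control_post = [[1.0, 3.0], [3.0, 5.0]]  # mean [2, 4]

counterfactual_mean = parallel_trends_counterfactual(
    treated_pre, control_pre, control_post
)
print(counterfactual_mean)
```

The subtraction `control_post.mean(axis=0) - control_pre.mean(axis=0)` is the step that relies on vector space structure. According to the abstract, this is the operation the paper replaces with parallel transport along optimal transport geodesics when the objects are distributions rather than averages.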

Paper Structure (40 sections, 342 equations, 8 figures, 3 tables, 2 algorithms)

Figures (8)

  • Figure 1: Visualizations of Wasserstein Parallel Transport with Gaussian measures in $\mathbb{R}^2$. The left panel illustrates how parallel transport captures deformations to the covariance, while the right panel illustrates how it captures changes in the mean.
  • Figure 2: Visualization of approximate Wasserstein parallel transport, where the approximate tangent vector is given by the map $p \mapsto {\operatorname{PT}}_{\exp_p(\nabla\varphi(p))}(v(p))$ (left), and exact Wasserstein parallel transport, where the true tangent vector is given by the map $p \mapsto({\operatorname{PT}}_{\mathbf{exp}_{\nu}(\nabla\varphi)}(v))(p)$ (right).
  • Figure 3: Estimated projection of a non-conservative vector field onto the space of gradient fields using the procedure described in \ref{sec: helmholtz decomp}.
  • Figure 4: Visualization of the counterfactual dynamics prediction procedure described in \ref{alg: W trajectory reconstruction}.
  • Figure 5: Difference-in-differences diagram illustrating the parallel trends assumption in the Euclidean case (left) and the Wasserstein case (right). The goal is to reconstruct the counterfactual (denoted $Y_t(0)$ on the left and $\mu_t^*$ on the right).
  • ...and 3 more figures

Theorems & Definitions (5)

  • proof (5 entries)