
Distribution-Conditioned Transport

Nic Fishman, Gokul Gowri, Paolo L. B. Fischer, Marinka Zitnik, Omar Abudayyeh, Jonathan Gootenberg

TL;DR

This paper introduces distribution-conditioned transport (DCT), a framework that conditions transport maps on learned embeddings of the source and target distributions, enabling generalization to unseen distribution pairs.

Abstract

Learning a transport model that maps a source distribution to a target distribution is a canonical problem in machine learning, but scientific applications increasingly require models that can generalize to source and target distributions unseen during training. We introduce distribution-conditioned transport (DCT), a framework that conditions transport maps on learned embeddings of source and target distributions, enabling generalization to unseen distribution pairs. DCT also allows semi-supervised learning for distributional forecasting problems: because it learns from arbitrary distribution pairs, it can leverage distributions observed at only one condition to improve transport prediction. DCT is agnostic to the underlying transport mechanism, supporting models ranging from flow matching to distributional divergence-based models (e.g. Wasserstein, MMD). We demonstrate the practical performance benefits of DCT on synthetic benchmarks and four applications in biology: batch effect transfer in single-cell genomics, perturbation prediction from mass cytometry data, learning clonal transcriptional dynamics in hematopoiesis, and modeling T-cell receptor sequence evolution.
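As a rough, hypothetical illustration of the conditioning idea, the sketch below pairs a mean-pooled encoder standing in for $\mathcal{E}$ with a shift-only map standing in for $\mathcal{T}$, then scores the pushed samples with a Gaussian-kernel MMD, one of the divergences mentioned above. The names `encode`, `transport`, and `mmd2`, the moment features, and the kernel bandwidth are assumptions made for this example, not the paper's architecture; in DCT the embeddings are learned and the transport model would be a trained network.

```python
import math
import random


def encode(samples):
    """Hypothetical distribution encoder: mean-pool simple per-point features.

    Mean pooling is permutation invariant, so the embedding depends only on
    the empirical measure of the sample set.
    """
    feats = [(x, x * x) for x in samples]  # per-point features (first two moments)
    m = len(feats)
    return tuple(sum(f[k] for f in feats) / m for k in range(2))


def transport(x, src_emb, tgt_emb):
    """Hypothetical conditional map T(x | E(source), E(target)).

    This toy map only shifts x by the difference of the embedded means; a
    real model would be learned (e.g. a flow-matching vector field).
    """
    return x + (tgt_emb[0] - src_emb[0])


def mmd2(xs, ys, bandwidth=1.0):
    """Biased (V-statistic) estimate of squared MMD with a Gaussian kernel."""
    def k(a, b):
        return math.exp(-((a - b) ** 2) / (2 * bandwidth ** 2))

    kxx = sum(k(a, b) for a in xs for b in xs) / len(xs) ** 2
    kyy = sum(k(a, b) for a in ys for b in ys) / len(ys) ** 2
    kxy = sum(k(a, b) for a in xs for b in ys) / (len(xs) * len(ys))
    return kxx + kyy - 2 * kxy


rng = random.Random(0)
source = [rng.gauss(0.0, 1.0) for _ in range(200)]
target = [rng.gauss(3.0, 1.0) for _ in range(200)]

# Condition the map on both embeddings, push the source samples, and score them.
s_emb, t_emb = encode(source), encode(target)
pushed = [transport(x, s_emb, t_emb) for x in source]
loss_before = mmd2(source, target)
loss_after = mmd2(pushed, target)
print(f"MMD^2 before: {loss_before:.4f}, after: {loss_after:.4f}")
```

Because the map sees the source and target only through their embeddings, the same function can be applied to any (source, target) pair, including pairs never seen together during training.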

Paper Structure

This paper contains 142 sections, 7 theorems, 53 equations, 6 figures, 21 tables, and 1 algorithm.

Key Result

Proposition 1.2

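Fix $m\in\mathbb N$ and let $\mathcal{E}_m:\mathcal{X}^m\to\mathbb R^d$ be permutation invariant. Then there exists a measurable map $\phi_m$ defined on the set of empirical measures $\{\widehat{P}_m:\,S_m\in\mathcal{X}^m\}$ such that

$$\mathcal{E}_m(S_m) = \phi_m\big(\widehat{P}_m\big) \quad \text{for all } S_m = (x_1,\dots,x_m)\in\mathcal{X}^m,$$

where $\widehat{P}_m = \frac{1}{m}\sum_{i=1}^{m}\delta_{x_i}$ is the empirical measure of $S_m$.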
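Definition 1.1 and Propositions 1.2 and 1.3 can be made concrete with two DeepSets-style pooling encoders. This sketch assumes that proportional invariance means the output is unchanged when every sample is replicated the same number of times (which leaves $\widehat{P}_m$ unchanged); the per-point feature map is an arbitrary hypothetical choice. Mean pooling satisfies both invariances, while sum pooling is only permutation invariant: for each fixed $m$ it still factorizes through $\widehat{P}_m$ (as $m$ times the mean-pooled value), but no single functional works across sample sizes.

```python
import math
from itertools import permutations


def point_features(x):
    """Per-point feature map (hypothetical choice for illustration)."""
    return (math.sin(x), math.cos(x), x * x)


def mean_pool(samples):
    """Permutation- and proportionally-invariant encoder: average of features."""
    feats = [point_features(x) for x in samples]
    return tuple(sum(col) / len(feats) for col in zip(*feats))


def sum_pool(samples):
    """Permutation invariant, but its output scales with the set size m."""
    feats = [point_features(x) for x in samples]
    return tuple(sum(col) for col in zip(*feats))


def close(u, v, tol=1e-12):
    """Elementwise comparison that tolerates floating-point reordering error."""
    return all(abs(a - b) <= tol for a, b in zip(u, v))


S = [0.3, -1.2, 2.5, 0.3]
doubled = S + S  # same empirical measure, twice the sample size

# Permutation invariance: every ordering of S yields the same embedding.
assert all(close(mean_pool(list(p)), mean_pool(S)) for p in permutations(S))
assert all(close(sum_pool(list(p)), sum_pool(S)) for p in permutations(S))

# Proportional invariance: only mean pooling ignores the replication.
print("mean pool unchanged:", close(mean_pool(doubled), mean_pool(S)))
print("sum pool unchanged: ", close(sum_pool(doubled), sum_pool(S)))
```

Running it prints `True` for mean pooling and `False` for sum pooling.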

Figures (6)

  • Figure 1: The distribution-conditioned transport framework. A source distribution (teal) is pushed to a target distribution (purple) by a transport model $\mathcal{T}$ which is conditioned on distribution embeddings learned by an encoder $\mathcal{E}$. The learned transport map is universal in the sense that any distribution can in principle be pushed to any distribution by conditioning on the corresponding source and target embeddings.
  • Figure 2: Bivariate normal transport error landscape. $K$-to-$K$ (top) vs. any-to-any (bottom), showing the $W_2$ distance for targets $\mu \in [0,5]^2$ across values of $K$. The $K$-to-$K$ model predicts via the nearest training distribution, yielding Voronoi-like error patterns; the any-to-any model embeds targets directly and achieves uniformly lower error across all $K$.
  • Figure 3: Unsupervised transport model generalization. The performance gap between $K$-to-$K$ embedding-based encoders and any-to-any distribution encoders across generator families, evaluated on in-distribution (IID) and out-of-distribution (OOD) test sets. Positive values indicate that the distribution encoder achieves lower transport cost. At low $K$, the embedding encoder performs better on IID data, while the distribution encoder generalizes better OOD.
  • Figure 4: Batch effect transfer for two held-out donor pairs (rows). Blue contours represent the reference density ($\sim 3\cdot10^5$ cells) computed from all 56 donors. (Left) Ground-truth cells from the target donor. (Center) Predictions from DCT and (right) from the $K$-to-$K$ model. DCT predictions match the ground truth more closely than $K$-to-$K$ predictions.
  • Figure 5: Semi-supervised transport model generalization. Supervised models are trained on distributions with $\|\mu\|_\infty \leq 2.5$ (shaded) and evaluated across $\|\mu\|_\infty \leq 5$. Supervised models (red), except those based on flow matching, degrade sharply outside the training support, while semi-supervised models (blue) maintain stable performance approaching the oracle (green). Results are shown for a multivariate normal (left) and a Gaussian mixture (right).
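The Voronoi-like pattern in Figure 2 follows from a short calculation if, as a simplifying assumption, all distributions are $\mathcal{N}(\mu, I)$ with a shared covariance. Then $W_2$ between two of them is just $\|\mu_1-\mu_2\|_2$, and an idealized $K$-to-$K$ model that returns the nearest training target has error equal to the distance from $\mu$ to the closest training mean. The grid of training means and the function names below are hypothetical.

```python
import math


def w2_gaussian_same_cov(mu_a, mu_b):
    """W2 between N(mu_a, S) and N(mu_b, S) with a shared covariance S.

    With equal covariances the covariance (Bures) term vanishes, leaving the
    Euclidean distance between the means.
    """
    return math.dist(mu_a, mu_b)


def k_to_k_error(target_mu, train_mus):
    """Error of an idealized K-to-K model that outputs the nearest training target."""
    nearest = min(train_mus, key=lambda mu: math.dist(mu, target_mu))
    return w2_gaussian_same_cov(nearest, target_mu), nearest


# Hypothetical training targets: a 2x2 grid of means inside [0, 5]^2 (K = 4).
train_mus = [(1.25, 1.25), (1.25, 3.75), (3.75, 1.25), (3.75, 3.75)]

# Coarse error landscape over [0, 5]^2, printed top row first.
for y in [4.5, 3.5, 2.5, 1.5, 0.5]:
    row = []
    for x in [0.5, 1.5, 2.5, 3.5, 4.5]:
        err, _ = k_to_k_error((x, y), train_mus)
        row.append(f"{err:.2f}")
    print(" ".join(row))
```

The error is zero at each training mean and grows toward the boundaries between their Voronoi cells; a model that embeds the target directly is not tied to this nearest-neighbor structure.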

Theorems & Definitions (15)

  • Definition 1.1: Permutation and proportional invariance
  • Proposition 1.2: Permutation invariance $\Rightarrow$ factorization
  • Proof
  • Proposition 1.3: Permutation + proportional invariance $\Rightarrow$ a single functional
  • Proof
  • Theorem 1.4: Encoder CLT
  • Proof
  • Theorem 1.5: Plug-in loss CLTs for DCT
  • Proof
  • Corollary 1.6: Mean consistency and (asymptotic) unbiasedness of plug-in training
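The exact hypotheses of Theorem 1.4 (Encoder CLT) are not reproduced here. In the special case of a mean-pooled encoder $\mathcal{E}(P)=\mathbb{E}_P[f(X)]$, which is an assumption for this sketch, the encoder CLT reduces to the classical CLT: $\sqrt{m}\,\big(\mathcal{E}(\widehat{P}_m)-\mathcal{E}(P)\big)$ is asymptotically normal with covariance $\mathrm{Cov}_P[f(X)]$. The simulation below checks the $\sqrt{m}$ scaling with $f(x)=x$ and a skewed $\mathrm{Exp}(1)$ source; the helper names are hypothetical.

```python
import math
import random
import statistics


def mean_pool_encoder(samples):
    """Hypothetical scalar encoder E(P_hat_m): mean of f(x) with f(x) = x."""
    return sum(samples) / len(samples)


def scaled_error_sd(m, reps=2000, seed=0):
    """Std of sqrt(m) * (E(P_hat_m) - E(P)) over repeated size-m draws from Exp(1)."""
    rng = random.Random(seed)
    true_embedding = 1.0  # E[X] for X ~ Exp(1)
    errors = []
    for _ in range(reps):
        sample = [rng.expovariate(1.0) for _ in range(m)]
        errors.append(math.sqrt(m) * (mean_pool_encoder(sample) - true_embedding))
    return statistics.pstdev(errors)


for m in (16, 64, 256):
    # The CLT predicts sd(f(X)) = 1 for Exp(1), independent of m.
    print(m, round(scaled_error_sd(m), 3))
```

Each printed value should be close to 1, i.e. the embedding error shrinks like $1/\sqrt{m}$ as the sample size grows.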