Consistent Optimal Transport with Empirical Conditional Measures

Piyushi Manupriya, Rachit Keerti Das, Sayantan Biswas, Saketha Nath Jagarlapudi

TL;DR

This work introduces Conditional Optimal Transport (COT), a framework to compare and transport conditional distributions when only joint samples are available, the conditioning variable may be continuous, and its marginals may differ between the two joint distributions. It regularizes the COT objective with kernel-based MMD terms that align the transport plan's marginals with the empirical conditionals, and it proves consistency together with a finite-sample bound of $O(1/m^{1/4})$ under mild assumptions. A key design choice is to factor the transport plan into $\pi_{Y|X}$ and $\pi_{Y'|Y,X}$, which simplifies modelling and allows practical inference with either explicit or implicit models. The method is validated on synthetic tasks where the true Wasserstein distances and barycenters are known, and applied to cell population dynamics and prompt learning for few-shot classification, where it outperforms state-of-the-art baselines. Overall, COT provides a principled and scalable way to compare conditionals and yields practical gains in conditional generation and discriminative tasks.
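
Written out, the factorization named above reads as follows. This is a restatement in the notation of the Figure 1 caption, not an additional result:

$$\pi_{Y, Y'|X}(y, y'|x) = \pi_{Y|X}(y|x)\, \pi_{Y'|Y,X}(y'|y, x)$$

Sampling from the plan for a given $x$ then reduces to two steps: draw $y$ from the first factor, then draw $y'$ from the second factor given $(y, x)$.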

Abstract

Given samples from two joint distributions, we consider the problem of Optimal Transportation (OT) between them when conditioned on a common variable. We focus on the general setting where the conditioned variable may be continuous, and the marginals of this variable in the two joint distributions may not be the same. In such settings, standard OT variants cannot be employed, and novel estimation techniques are necessary. Since the main challenge is that the conditional distributions are not explicitly available, the key idea in our OT formulation is to employ kernelized-least-squares terms computed over the joint samples, which implicitly match the transport plan's marginals with the empirical conditionals. Under mild conditions, we prove that our estimated transport plans, as a function of the conditioned variable, are asymptotically optimal. For finite samples, we show that the deviation in terms of our regularized objective is bounded by $O(1/m^{1/4})$, where $m$ is the number of samples. We also discuss how the conditional transport plan could be modelled using explicit probabilistic models as well as using implicit generative ones. We empirically verify the consistency of our estimator on synthetic datasets, where the optimal plan is analytically known. When employed in applications like prompt learning for few-shot classification and conditional-generation in the context of predicting cell responses to treatment, our methodology improves upon state-of-the-art methods.
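
As a rough illustration of the kernel-based matching idea, the sketch below computes a plain two-sample estimate of the squared maximum mean discrepancy (MMD) with a Gaussian kernel, which is characteristic and satisfies $k(x, x) = 1$. It is not the paper's regularizer, which is computed over joint samples and compared against the plan's marginals; the function names, the bandwidth, and the toy data are all assumptions made for illustration.

```python
import math
import random


def gaussian_kernel(a, b, bandwidth=1.0):
    """Gaussian kernel k(a, b) = exp(-||a - b||^2 / (2 * bandwidth^2))."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mean_kernel(xs, ys, bandwidth=1.0):
    """Average of k(x, y) over all pairs drawn from the two sample sets."""
    total = sum(gaussian_kernel(x, y, bandwidth) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd_squared(xs, ys, bandwidth=1.0):
    """Biased (V-statistic) estimate of MMD^2 between two sample sets."""
    return (
        mean_kernel(xs, xs, bandwidth)
        + mean_kernel(ys, ys, bandwidth)
        - 2.0 * mean_kernel(xs, ys, bandwidth)
    )


def sample_gaussian(rng, n, dim, loc=0.0):
    """Draw n points from an isotropic unit-variance Gaussian centred at loc."""
    return [[rng.gauss(loc, 1.0) for _ in range(dim)] for _ in range(n)]


# Toy data (hypothetical): two samples from the same law, one from a shifted law.
rng = random.Random(0)
p = sample_gaussian(rng, 100, 2)
q_same = sample_gaussian(rng, 100, 2)
q_shifted = sample_gaussian(rng, 100, 2, loc=2.0)

mmd_same = mmd_squared(p, q_same)
mmd_shifted = mmd_squared(p, q_shifted)
print(f"same law: {mmd_same:.4f}  shifted law: {mmd_shifted:.4f}")
```

The estimate is close to zero when both samples come from the same distribution and grows as the distributions separate, which is the property that makes such a term usable as a penalty for mismatched marginals.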

Paper Structure (33 sections, 3 theorems, 22 equations, 12 figures, 13 tables, 1 algorithm)

Key Result

Lemma 1

Assuming $k$ is a normalized characteristic kernel, with probability at least $1-\delta$, we have

Figures (12)

  • Figure 1: Illustration of the proposed factorization and implicit modelling for learning the transport plan $\pi_{Y, Y'|X}(y, y'|x)$ through the factors $\pi_\theta(y|x) \pi_\psi(y'|y, x)$, parameterized by fixed-architecture neural networks. $\eta, \ \eta'\sim \mathcal{N}(0, 1)$ denote the noise inputs to the implicit models.
  • Figure 2: As $m\in\{100, 200, 400, 800\}$ increases from left to right, we plot the true Wasserstein distance in red and mark the means (in orange) and medians (in green) of the distances estimated using Tabak21 and the proposed COT estimator. The statistics are obtained from runs over multiple seeds. The corresponding MSEs are $\{ 245.530, 290.458, 89.715, 27.687\}$ for Tabak21 and $\{22.711, 6.725, 8.052, 1.580\}$ for the proposed COT estimator. It can be seen that the proposed COT objective converges to the true Wasserstein distance faster than Tabak21.
  • Figure 3: Barycenters shown on varying $\rho\in[0, 1]$ with colors interpolated between red and blue. Left: Conditional barycenter learnt by the proposed COT method. Right: Analytical barycenter.
  • Figure 4: For increasing values of $m$, we show box plots of the Wasserstein distance between the learnt barycenter, $B_x$, and the analytical barycenter. The corresponding MSEs are $\{22.399, 3.408, 3.964, 2.534, 1.687\}$ for Tabak21 and $\{4.441, 0.654, 0.353, 0.099, 0.058\}$ for the proposed COT estimator. It can be seen that the proposed COT-based barycenter converges to the true solution faster than Tabak21.
  • Figure 5: We pose prompt learning in few-shot classification as a conditional optimal transport problem. The figure shows our neural network diagram for learning conditional optimal transport plans.
  • ...and 7 more figures
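
The implicit modelling described in the Figure 1 caption can be sketched as two-stage sampling: one generator maps $(x, \eta)$ to $y$, and a second maps $(y, x, \eta')$ to $y'$. The code below is only an illustration under stated assumptions: the networks are tiny, untrained, randomly initialised one-hidden-layer models, the noise is scalar, and every name (`make_mlp`, `sample_plan`, `g_theta`, `g_psi`) is hypothetical rather than taken from the paper.

```python
import math
import random


def make_mlp(rng, in_dim, hidden_dim, out_dim):
    """Build a one-hidden-layer tanh network with random, untrained weights."""
    w1 = [[rng.gauss(0.0, 1.0 / math.sqrt(in_dim)) for _ in range(in_dim)]
          for _ in range(hidden_dim)]
    w2 = [[rng.gauss(0.0, 1.0 / math.sqrt(hidden_dim)) for _ in range(hidden_dim)]
          for _ in range(out_dim)]

    def forward(inputs):
        hidden = [math.tanh(sum(w * v for w, v in zip(row, inputs))) for row in w1]
        return [sum(w * h for w, h in zip(row, hidden)) for row in w2]

    return forward


def sample_plan(x, g_theta, g_psi, rng, n_samples):
    """Draw (y, y') pairs for a fixed x from the factorized implicit plan."""
    pairs = []
    for _ in range(n_samples):
        eta = rng.gauss(0.0, 1.0)        # noise for the first factor
        eta_prime = rng.gauss(0.0, 1.0)  # noise for the second factor
        y = g_theta(x + [eta])                # y  ~ pi_theta(. | x)
        y_prime = g_psi(y + x + [eta_prime])  # y' ~ pi_psi(. | y, x)
        pairs.append((y, y_prime))
    return pairs


# Hypothetical dimensions and conditioning value, chosen only for the demo.
rng = random.Random(0)
x_dim, y_dim = 3, 2
g_theta = make_mlp(rng, x_dim + 1, 8, y_dim)
g_psi = make_mlp(rng, y_dim + x_dim + 1, 8, y_dim)

pairs = sample_plan([0.5, -1.0, 2.0], g_theta, g_psi, rng, n_samples=5)
for y, y_prime in pairs:
    print([round(v, 3) for v in y], [round(v, 3) for v in y_prime])
```

In a trained model the two generators would be fitted to the COT objective; here they only show how the two factors compose when sampling.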

Theorems & Definitions (5)

  • Lemma 1
  • Theorem 1
  • proof
  • Corollary (Restated from \citation{rad})
  • proof