Pairwise Optimal Transports for Training All-to-All Flow-Based Condition Transfer Model

Kotaro Ikeda, Masanori Koyama, Jinzhe Zhang, Kohei Hayashi, Kenji Fukumizu

TL;DR

This work addresses learning all-to-all transfers between conditional distributions under continuous and sparse conditioning by introducing A2A-FM, a flow-based framework that learns pairwise transports across all condition pairs. A novel coupling objective builds minibatch-based permutations to approximate pairwise OT. It comes with a theoretical guarantee: in the limit of large sample size and large penalty weight $\beta$, the learned couplings converge to the optimal couplings attaining $W_2^2(P_{c_1},P_{c_2})$ for almost every $(c_1,c_2)$. The approach scales to non-grouped data and to large, high-dimensional settings. It achieves state-of-the-art performance in molecular property transfer and high-dimensional image-attribute transfer, and it remains computationally favorable relative to prior multimarginal methods. The work provides practical tools for conditional generation and design tasks in which continuous conditions are central. Code is available at the repository linked in the abstract.
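The minibatch coupling step can be illustrated as a small assignment problem. The sketch below is a hypothetical stand-in, not the paper's Eq. (costA2A). It assumes the cost adds the squared Euclidean transport distance to a $\beta$-weighted penalty on changing the condition pair. This is one reading of the Figure 2 caption, in which large $\beta$ favors permutations $\pi$ with $(c^{(i)}_1, c^{(i)}_2) = (c^{(\pi(i))}_1, c^{(\pi(i))}_2)$. The helper name `pairwise_condition_coupling` is illustrative. The permutation is solved exactly with SciPy's `linear_sum_assignment`.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def pairwise_condition_coupling(x1, c1, x2, c2, beta):
    """Return pi such that x1[i] is coupled with x2[pi[i]].

    HYPOTHETICAL cost (illustration only, not the paper's exact objective):
        sum_i ||x1[i] - x2[pi[i]]||^2
              + beta * (||c1[i] - c1[pi[i]]||^2 + ||c2[i] - c2[pi[i]]||^2)
    Slot i carries the condition pair (c1[i], c2[i]). A large beta forces pi
    to move samples only between slots that share the same condition pair,
    so within each pair the x's are matched by a plain OT assignment.
    """
    # Transport term: squared distance between every x1[i] and x2[j].
    cost_x = ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(axis=-1)
    # Condition term: penalty for sending slot i to a slot j with another pair.
    cost_c = ((c1[:, None, :] - c1[None, :, :]) ** 2).sum(axis=-1) + (
        (c2[:, None, :] - c2[None, :, :]) ** 2
    ).sum(axis=-1)
    # Exact minimum over all permutations via an assignment solver.
    rows, cols = linear_sum_assignment(cost_x + beta * cost_c)
    pi = np.empty_like(cols)
    pi[rows] = cols
    return pi


# Usage: three discrete conditions {0, 1, 2}, mirroring the Figure 2 toy setup.
rng = np.random.default_rng(0)
n = 12
x1, x2 = rng.normal(size=(n, 2)), rng.normal(size=(n, 2))
c1 = rng.integers(0, 3, size=(n, 1)).astype(float)
c2 = rng.integers(0, 3, size=(n, 1)).astype(float)
pi = pairwise_condition_coupling(x1, c1, x2, c2, beta=1e6)
print(pi)
```

With $\beta = 0$ the same routine reduces to plain minibatch OT between the two batches and ignores the conditions. Increasing $\beta$ trades transport cost for agreement of the condition pairs.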

Abstract

In this paper, we propose a flow-based method for learning all-to-all transfer maps among conditional distributions that approximates pairwise optimal transport. The method targets continuous conditions, which often involve a large set of conditions with sparse empirical observations per condition. We introduce a novel cost function that enables simultaneous learning of optimal transports for all pairs of conditional distributions. Our method is supported by a theoretical guarantee that, in the limit, it converges to the pairwise optimal transports among infinitely many pairs of conditional distributions. The learned transport maps are then used to couple data points in conditional flow matching. We demonstrate the effectiveness of this method on synthetic and benchmark datasets, as well as on chemical datasets in which continuous physical properties are defined as conditions. The code for this project can be found at https://github.com/kotatumuri-room/A2A-FM
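Once pairs are coupled, conditional flow matching regresses a velocity field onto the displacement between them. The sketch below assumes the common straight-line path $x_t = (1-t)\,x_{\rm src} + t\,x_{\rm targ}$ with target velocity $x_{\rm targ} - x_{\rm src}$. The paper's exact path, time sampling, and network inputs may differ. The names `straight_path` and `cfm_loss`, and the `model(x_t, t, c_src, c_targ)` signature, are hypothetical.

```python
import numpy as np


def straight_path(x_src, x_targ, t):
    """Point on the straight path at time t and the velocity along it."""
    t = t[:, None]  # one time per sample, broadcast over feature dims
    x_t = (1.0 - t) * x_src + t * x_targ
    u_t = x_targ - x_src  # constant velocity of the straight path
    return x_t, u_t


def cfm_loss(model, x_src, c_src, x_targ, c_targ, rng):
    """Mean squared error between the model velocity and the path velocity.

    Assumes (x_src, x_targ) were already coupled, e.g. by a minibatch
    permutation, and that the model is conditioned on both conditions.
    """
    t = rng.uniform(0.0, 1.0, size=x_src.shape[0])
    x_t, u_t = straight_path(x_src, x_targ, t)
    v = model(x_t, t, c_src, c_targ)
    return float(np.mean(np.sum((v - u_t) ** 2, axis=-1)))


def unit_model(x_t, t, c_src, c_targ):
    """Placeholder velocity field that always predicts a unit displacement."""
    return np.ones_like(x_t)


# Usage: source at the origin, target at (1, 1, 1); the placeholder is exact here.
rng = np.random.default_rng(0)
x_src = np.zeros((8, 3))
x_targ = np.ones((8, 3))
c_src, c_targ = np.zeros((8, 1)), np.ones((8, 1))
print(cfm_loss(unit_model, x_src, c_src, x_targ, c_targ, rng))
```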

Paper Structure

This paper contains 28 sections, 5 theorems, 55 equations, 18 figures, 6 tables, 3 algorithms.

Key Result

Proposition 3.1

Let $\Pi^*_{\beta}$ be the joint distribution on $(\mathcal{X} \times \mathcal{C}) \times (\mathcal{X} \times \mathcal{C})$ defined by the coupling $\pi_\beta^*$ that minimizes the cost in Eq. (costA2A), that is, … Then, for any sequence $\beta_k \to \infty$, there exists an increasing sequence of sample sizes $N_k$ such that $\Pi^*_{\beta_k}$ converges to a distribution $\Pi^*$ for which $\Pi^*(\cdot,\cdot \mid c_1,c_2)$ …
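For reference, the quantity in the limit statement is the squared 2-Wasserstein distance between two conditional distributions. The standard definition is given below, where $\Gamma(P_{c_1},P_{c_2})$ denotes the set of joint distributions on $\mathcal{X}\times\mathcal{X}$ with marginals $P_{c_1}$ and $P_{c_2}$. Per the TL;DR, the limit $\Pi^*(\cdot,\cdot \mid c_1,c_2)$ corresponds to a coupling attaining this infimum for almost every $(c_1,c_2)$.

```latex
W_2^2(P_{c_1}, P_{c_2})
  = \inf_{\gamma \in \Gamma(P_{c_1}, P_{c_2})}
    \int_{\mathcal{X} \times \mathcal{X}} \lVert x_1 - x_2 \rVert^2 \, \mathrm{d}\gamma(x_1, x_2)
```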

Figures (18)

  • Figure 1: (a) The task is to transport $x_{\rm src} \sim P_{c_{\rm src}}$ to generate $x_{\rm targ} \sim P_{c_{\rm targ}}$ for an arbitrary pair $(c_{\rm src},c_{\rm targ})$, where $P_c$ denotes the conditional distribution. Red and blue arrows represent the cases $(c_{\rm src}, c_{\rm targ}) = (c^{(1)}, c^{(2)})$ and $(c_{\rm src}, c_{\rm targ}) = (c^{(3)}, c^{(4)})$, respectively. (b) Left: Grouped data can be partitioned into large subsets $D_{c^{(i)}}$ whose members are i.i.d. samples from $P_{c^{(i)}}$. Many condition transfer methods, including Multimarginal Stochastic Interpolants (SI) albergo2023multimarginal and Extended Flow Matching (EFM) isobe2024extended, rely on this data format. Right: In non-grouped data, the sample for a given condition may be unique. The proposed method, A2A-FM, learns condition transfer in both cases in the form of pairwise optimal transport (see the related-work section for a comparison).
  • Figure 2: (a) Batches $B_1, B_2$ drawn independently from $P$ on $\mathcal{X} \times \mathcal{C}$, where $\mathcal{C}=\{\mathpzc{a}, \mathpzc{b}, \mathpzc{c}\}$. (b) Couplings between $B_1$ and $B_2$ under the cost in Eq. (costA2A). With large $\beta$, the cost favors permutations $\pi$ such that $(c^{(i)}_1, c^{(i)}_2) = (c^{(\pi(i))}_1, c^{(\pi(i))}_2)$.
  • Figure 3: (a) Results for grouped data, with $10^3$ samples in each of the 3 conditions $\{c^{(0)},c^{(1)},c^{(2)}\}$ and $\beta=10^4$. (b) Results for non-grouped data. The gray points in the background are samples from the training dataset. The pairwise OT shown is a numerical approximation computed with POT flamary2021pot. The red lines in the right column show the bins used to train Multimarginal SI.
  • Figure 4: Sampling Efficiency of A2A-FM and the partial diffusion model of kaufman2024coati.
  • Figure 5: Sampling Efficiency Curve for the LogP-TPSA benchmark. See the LogP-TPSA appendix and the table of AUC values for notation. $K$ is the number of discretization bins.
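Figure 3 compares against a numerical pairwise OT computed with POT flamary2021pot. For two empirical samples of equal size with uniform weights, an optimal coupling can be chosen to be a permutation. The empirical $W_2^2$ can therefore also be computed exactly with an assignment solver. The sketch below uses SciPy's `linear_sum_assignment` instead of POT to keep dependencies light. The helper `empirical_w2_squared` is hypothetical and is not code from the paper's repository.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def empirical_w2_squared(x, y):
    """Exact W_2^2 between uniform empirical measures on equal-size samples.

    With n points on each side and weights 1/n, an optimal coupling can be
    chosen as a permutation, so an assignment solver gives the exact value.
    """
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape (n, d)")
    # Pairwise squared Euclidean costs between all points of x and y.
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    rows, cols = linear_sum_assignment(cost)
    return float(cost[rows, cols].mean())


# Usage: shifting a 1-D sample by 2 gives W_2^2 = 2^2 = 4.
x = np.array([[0.0], [1.0], [2.0]])
print(empirical_w2_squared(x, x + 2.0))  # 4.0
```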

Theorems & Definitions (8)

  • Proposition 3.1: Informal
  • Proposition A.1
  • Theorem A.2
  • proof
  • Proposition A.3
  • proof
  • Proposition A.4
  • proof