Optimal Transport for Treatment Effect Estimation

Hao Wang, Zhichao Chen, Jiajun Fan, Haoxuan Li, Tianqiao Liu, Weiming Liu, Quanyu Dai, Yichao Wang, Zhenhua Dong, Ruiming Tang

TL;DR

This work tackles conditional average treatment effect (CATE) estimation from observational data by addressing two core challenges: mini-batch sampling effects (MSE) and unobserved confounder effects (UCE). It introduces Entire Space CounterFactual Regression (ESCFR), a stochastic optimal transport (OT) framework that combines a relaxed mass-preserving regularizer (RMPR), which targets MSE, with a proximal factual outcome regularizer (PFOR), which targets UCE, so that the representations of different treatment groups are aligned robustly for counterfactual prediction. The method uses an end-to-end architecture with a representation mapping $\psi$ and an outcome mapping $\phi$, and it optimizes an objective $\mathcal{L}_{\mathrm{ESCFR}}^{\epsilon,\kappa,\gamma,\lambda}$ that combines the factual outcome risk with a generalized Sinkhorn discrepancy. Empirical results on IHDP and ACIC show that ESCFR outperforms state-of-the-art baselines, and ablations confirm that RMPR and PFOR each contribute to improved $\mathrm{PEHE}$ and $\mathrm{AUUC}$ in both in-sample and out-of-sample settings, suggesting practical value for causal inference in healthcare, policy, and recommender systems.
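
The TL;DR reports results in terms of PEHE. As a minimal sketch, assuming the common root-mean-squared definition $\sqrt{\frac{1}{n}\sum_i (\hat{\tau}_i - \tau_i)^2}$, where $\tau_i$ is the true individual effect (available in semi-synthetic benchmarks such as IHDP), the metric can be computed as below. The function name `pehe` is hypothetical, and whether the paper reports the root or the squared form is not stated in this excerpt.

```python
import numpy as np


def pehe(mu1, mu0, mu1_hat, mu0_hat):
    """Root PEHE: RMSE between true and estimated individual treatment effects.

    mu1, mu0         -- true potential outcomes (known only in semi-synthetic data)
    mu1_hat, mu0_hat -- model predictions of the two potential outcomes
    """
    tau = np.asarray(mu1, dtype=float) - np.asarray(mu0, dtype=float)
    tau_hat = np.asarray(mu1_hat, dtype=float) - np.asarray(mu0_hat, dtype=float)
    return float(np.sqrt(np.mean((tau_hat - tau) ** 2)))


# Toy usage: the model is exact on two units and off by 2 on the third.
score = pehe(mu1=[3.0, 4.0, 5.0], mu0=[2.0, 2.0, 2.0],
             mu1_hat=[3.0, 4.0, 7.0], mu0_hat=[2.0, 2.0, 2.0])
print(round(score, 4))  # 1.1547
```

Because the true effect $\tau_i$ requires both potential outcomes, this metric is only computable on benchmarks where the counterfactual is simulated.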

Abstract

Estimating the conditional average treatment effect from observational data is highly challenging due to treatment selection bias. Prevalent methods mitigate this issue by aligning the distributions of different treatment groups in the latent space. However, these methods fail to address two critical problems: (1) mini-batch sampling effects (MSE), which cause misalignment in non-ideal mini-batches with outcome imbalance and outliers; and (2) unobserved confounder effects (UCE), which result in inaccurate discrepancy calculation because unobserved confounders are neglected. To tackle these problems, we propose a principled approach named Entire Space CounterFactual Regression (ESCFR), which is a new take on optimal transport in the context of causality. Specifically, based on the framework of stochastic optimal transport, we propose a relaxed mass-preserving regularizer to address the MSE issue and design a proximal factual outcome regularizer to handle the UCE issue. Extensive experiments demonstrate that ESCFR successfully tackles the treatment selection bias and achieves significantly better performance than state-of-the-art methods.
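
The abstract builds on aligning treatment groups in the latent space with optimal transport computed on mini-batches. A minimal NumPy sketch of an entropic (Sinkhorn) transport cost between the treated and control representations of one mini-batch follows, assuming uniform unit weights and a squared-Euclidean cost; `eps` plays the role of the entropic regularization strength $\epsilon$, and the returned value is the transport cost $\langle \pi, C \rangle$ without the entropy term. The optional `gamma` term, which adds the squared difference between units' outcomes to the cost, is only a hypothetical stand-in for PFOR, whose exact form is not given in this excerpt. All function names are illustrative.

```python
import numpy as np


def sinkhorn_plan(cost, a, b, eps=1.0, n_iter=500):
    """Entropic OT plan via Sinkhorn matrix scaling with marginals a and b."""
    K = np.exp(-cost / eps)          # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)              # rescale rows toward marginal a
        v = b / (K.T @ u)            # rescale columns to marginal b
    return u[:, None] * K * v[None, :]


def minibatch_ot_discrepancy(r1, r0, y1=None, y0=None, eps=1.0, gamma=0.0, n_iter=500):
    """Transport cost <pi, C> between treated (r1) and control (r0) representations.

    gamma > 0 adds a hypothetical outcome-proximity term gamma * (y1_i - y0_j)^2
    to the cost (an illustrative stand-in for PFOR, not the paper's definition).
    """
    cost = ((r1[:, None, :] - r0[None, :, :]) ** 2).sum(axis=-1)
    if gamma > 0:
        cost = cost + gamma * (np.asarray(y1)[:, None] - np.asarray(y0)[None, :]) ** 2
    a = np.full(len(r1), 1.0 / len(r1))   # uniform mass per treated unit
    b = np.full(len(r0), 1.0 / len(r0))   # uniform mass per control unit
    plan = sinkhorn_plan(cost, a, b, eps=eps, n_iter=n_iter)
    return float((plan * cost).sum()), plan


rng = np.random.default_rng(0)
r1 = rng.normal(size=(8, 2))           # treated representations psi(X) in one batch
r0 = rng.normal(size=(10, 2))          # control representations
d_close, plan = minibatch_ot_discrepancy(r1, r0)
d_far, _ = minibatch_ot_discrepancy(r1, r0 + 5.0)   # shift the controls away
print(d_close < d_far)  # True
```

A larger value signals a stronger mismatch between the two groups in representation space; in training, a term of this kind would be minimized with respect to $\psi$ so that the groups overlap.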

Paper Structure (45 sections, 9 theorems, 50 equations, 8 figures, 4 tables, 3 algorithms)

Key Result

Theorem 3.1

Let $\psi$ and $\phi$ be the representation mapping and the factual outcome mapping, respectively, and let $\hat{\mathbb{W}}_\psi$ be the group discrepancy at the mini-batch level. With probability at least $1-\delta$, we have: … where $\epsilon^{T=1}_\mathrm{F}$ and $\epsilon^{T=0}_\mathrm{F}$ are the expected errors of factual outcome estimation, $N$ is the batch size, and $\sigma^2_Y$ is the variance of outcomes …
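
Theorem 3.1 bounds the estimation error through the factual errors $\epsilon^{T=1}_\mathrm{F}$ and $\epsilon^{T=0}_\mathrm{F}$ together with the mini-batch discrepancy $\hat{\mathbb{W}}_\psi$. As a rough sketch of how such terms could be assembled into a training objective on one mini-batch, assuming squared-error factual losses and a precomputed discrepancy value weighted by the alignment strength $\lambda$, consider the following. The function names and the exact weighting are illustrative assumptions, not the paper's loss.

```python
import numpy as np


def factual_errors(y, t, y1_hat, y0_hat):
    """Empirical factual MSE of the treated (T=1) and control (T=0) heads.

    Each unit is scored only on the head matching its observed treatment,
    because its counterfactual outcome is never observed.
    """
    y = np.asarray(y, dtype=float)
    t = np.asarray(t, dtype=int)
    y1_hat = np.asarray(y1_hat, dtype=float)
    y0_hat = np.asarray(y0_hat, dtype=float)
    err1 = np.mean((y1_hat[t == 1] - y[t == 1]) ** 2)
    err0 = np.mean((y0_hat[t == 0] - y[t == 0]) ** 2)
    return float(err1), float(err0)


def minibatch_objective(y, t, y1_hat, y0_hat, discrepancy, lam=1.0):
    """Factual risk of both groups plus lam times a group discrepancy."""
    err1, err0 = factual_errors(y, t, y1_hat, y0_hat)
    return err1 + err0 + lam * discrepancy


loss = minibatch_objective(y=[1.0, 2.0, 3.0, 4.0], t=[1, 1, 0, 0],
                           y1_hat=[1.0, 3.0, 0.0, 0.0], y0_hat=[0.0, 0.0, 3.0, 6.0],
                           discrepancy=0.25, lam=2.0)
print(loss)  # 3.0
```

Here the treated head errs by 1 on one of two treated units (error 0.5), the control head errs by 2 on one of two control units (error 2.0), and the weighted discrepancy adds 0.5.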

Figures (8)

  • Figure 1: Overview of handling treatment selection bias with ESCFR. The red and blue colors signify the treated and untreated groups, respectively. (a) The treatment selection bias manifests as a distribution shift between $X_1$ and $X_0$. The scatter points and curves represent the units and the fitted outcome mappings, respectively. (b) ESCFR mitigates the selection bias by aligning units from different treatment groups in the representation space $R=\psi(X)$, which enables $\phi_1$ and $\phi_0$ to generalize across groups.
  • Figure 2: Optimal transport plan (top) and its geometric interpretation (bottom) in three cases, where the connection strength represents the transported mass. Different colors and vertical positions indicate different treatments and outcomes, respectively.
  • Figure 3: Causal graphs (a) with and (b) without the unconfoundedness assumption. The shaded node indicates the hidden confounder $X^\prime$.
  • Figure 4: Geometric interpretation of the OT plan with RMPR under the outcome imbalance (top) and outlier (bottom) settings. The dark area indicates the transported mass of a unit, i.e., the marginal of the transport matrix $\pi$. The light area indicates the total mass.
  • Figure 5: Parameter sensitivity of ESCFR, where the lines and error bars indicate the mean values and 90% confidence intervals, respectively. (a) Impact of alignment strength $\lambda$. (b) Impact of entropic regularization strength $\epsilon$. (c) Impact of PFOR strength $\gamma\ (\times10^{3})$. (d) Impact of RMPR strength $\kappa$.
  • ...and 3 more figures
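
Figure 4 describes RMPR as letting a unit, such as an outlier, transport only part of its total mass. A standard way to obtain this behavior, assumed here rather than taken from the paper, is unbalanced OT: the hard marginal constraints are replaced by KL penalties of strength $\kappa$, and the plan is computed with the generalized Sinkhorn scaling iterations for unbalanced OT. The sketch below (function name hypothetical) compares an almost-balanced plan with a relaxed one on a toy batch containing one outlier.

```python
import numpy as np


def unbalanced_sinkhorn(cost, a, b, eps=0.5, kappa=1.0, n_iter=1000):
    """Entropic OT with KL-penalized marginals (generalized Sinkhorn scaling).

    kappa -> infinity recovers the balanced (mass-preserving) plan; a small
    kappa lets units transport less than their nominal mass.
    """
    K = np.exp(-cost / eps)
    power = kappa / (kappa + eps)    # < 1 softens each marginal constraint
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** power
        v = (b / (K.T @ u)) ** power
    return u[:, None] * K * v[None, :]


# 1-D toy batch: the last treated unit is an outlier far from every control unit.
x1 = np.array([-0.2, -0.1, 0.0, 0.1, 6.0])
x0 = np.array([-0.15, -0.05, 0.05, 0.15, 0.2])
cost = (x1[:, None] - x0[None, :]) ** 2
a = np.full(5, 0.2)
b = np.full(5, 0.2)

strict = unbalanced_sinkhorn(cost, a, b, kappa=1e6).sum(axis=1)   # nearly balanced
relaxed = unbalanced_sinkhorn(cost, a, b, kappa=1.0).sum(axis=1)  # relaxed marginals
print(strict[-1] > 0.1, relaxed[-1] < 1e-3)  # True True
```

The row sums play the role of the dark areas in Figure 4: with a very large $\kappa$ the outlier is still forced to ship its full mass to distant control units, whereas with $\kappa = 1$ its transported mass collapses toward zero. Note that the relaxed plan's total mass need not equal 1.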

Theorems & Definitions (25)

  • Definition 2.1
  • Definition 2.2
  • Definition 2.3
  • Definition 3.1
  • Theorem 3.1
  • Definition 3.2
  • Corollary 3.1
  • Definition A.1
  • Lemma A.1
  • Definition A.2
  • ...and 15 more