Multiple Treatments Causal Effects Estimation with Task Embeddings and Balanced Representation Learning

Yuki Murakami, Takumi Hattori, Kohsuke Kubota

TL;DR

CISI-Net addresses the challenge of estimating single and interaction causal effects under multiple treatments by coupling a task embedding network that captures treatment similarity with a balanced representation learning framework to reduce selection bias. The approach enables parameter sharing across related treatment patterns and learns representations that are aligned across treatment patterns via an IPM-based penalty, improving counterfactual predictions for all treatment vectors. Across extensive simulations and real-world marketing data, CISI-Net outperforms baselines, including when latent covariates are present or interaction effects are absent, demonstrating robust applicability to complex decision-making problems. The work also outlines future directions for adaptive balancing, sample-efficient strategies, and doubly robust extensions to further enhance causal inference in multi-treatment settings.

Abstract

The simultaneous application of multiple treatments is increasingly common in fields such as healthcare and marketing. In such settings, it is important to estimate both the single treatment effects and the interaction effects that arise from treatment combinations. Previous studies have proposed independent outcome networks with subnetworks for interactions, or task embedding networks that capture treatment similarity combined with variational autoencoders. However, the former lack parameter sharing among related treatments, and the latter estimate unnecessary latent variables, which reduces the accuracy of causal effect estimation. To address these issues, we propose a novel deep learning framework that combines a task embedding network with a representation learning network trained under a balancing penalty. The task embedding network enables parameter sharing across related treatment patterns because it encodes elements common to single effects as well as contributions specific to interaction effects. The representation learning network learns representations nonparametrically from observed covariates, while the balancing penalty reduces the distances between representation distributions across different treatment patterns. This mitigates selection bias and avoids model misspecification. Simulation studies demonstrate that the proposed method outperforms existing baselines, and an application to real-world marketing datasets confirms the practical implications and utility of our framework.
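The balancing penalty described above shrinks the distance between representation distributions across treatment patterns, and the TL;DR calls it IPM-based. The paper's exact IPM is not stated here, so the sketch below assumes a squared maximum mean discrepancy (MMD) with an RBF kernel, one member of the IPM family, and compares each observed treatment pattern's representations against the pooled representations. The function names `rbf_mmd2` and `balancing_penalty`, the kernel bandwidth `sigma`, and the group-versus-pooled aggregation are all hypothetical; the authors may use a different IPM (for example, a Wasserstein distance) or pairwise comparisons.

```python
import numpy as np

def rbf_mmd2(a, b, sigma=1.0):
    """Biased estimate of squared MMD between samples a (n, d) and b (m, d), RBF kernel."""
    def kernel(u, v):
        sq_dist = ((u[:, None, :] - v[None, :, :]) ** 2).sum(axis=-1)
        return np.exp(-sq_dist / (2.0 * sigma ** 2))
    return kernel(a, a).mean() + kernel(b, b).mean() - 2.0 * kernel(a, b).mean()

def balancing_penalty(phi, t, sigma=1.0):
    """Hypothetical L_Phi (assumption): mean MMD^2 between each treatment pattern's
    representations and the pooled representations of all units."""
    per_pattern = []
    for p in np.unique(t, axis=0):          # each distinct treatment vector in the data
        mask = np.all(t == p, axis=1)
        if mask.sum() < 2:                  # skip patterns too rare to compare
            continue
        per_pattern.append(rbf_mmd2(phi[mask], phi, sigma))
    return float(np.mean(per_pattern)) if per_pattern else 0.0
```

When representations do not depend on the treatment pattern, the penalty is close to zero; shifting the representations of treated units increases it. Over all $n$ units the kernel matrices cost $O(n^2)$, so in practice such a penalty would typically be computed per minibatch.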

Paper Structure

This paper contains 13 sections, 2 theorems, 16 equations, 5 figures, 2 tables.

Key Result

Proposition 1

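Under Assumptions 1-3, the conditional average potential outcome $\mu(\boldsymbol{x}, \boldsymbol{t})$ is identified and is equal to the conditional expectation of the observed outcome as follows:

$$\mu(\boldsymbol{x}, \boldsymbol{t}) = \mathbb{E}[Y(\boldsymbol{t}) \mid \boldsymbol{X} = \boldsymbol{x}] = \mathbb{E}[Y \mid \boldsymbol{X} = \boldsymbol{x}, \boldsymbol{T} = \boldsymbol{t}].$$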
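Proposition 1 is what licenses plug-in estimation: once $\mu(\boldsymbol{x}, \boldsymbol{t})$ equals the regression of the observed outcome on the covariates and the treatment vector, any fitted outcome model, such as CISI-Net's predictor, can be queried at counterfactual treatment vectors and averaged over the covariate sample. The sketch below assumes a callable `mu_hat(x, t)` that maps an `(n, d)` covariate array and an `(n, K)` binary treatment array to `n` predictions. The effect definitions are illustrative assumptions, not taken from the paper: the single effect of treatment $k$ is the average contrast between applying $k$ alone and no treatment, and the interaction effect of a combination $S$ is its joint effect minus the sum of its single effects. The helper names `pattern`, `mean_effect`, `ase`, and `aie` are hypothetical.

```python
import numpy as np

def pattern(n_treat, active):
    """Binary treatment vector with ones at the indices in `active`."""
    t = np.zeros(n_treat)
    for k in active:
        t[k] = 1.0
    return t

def mean_effect(mu_hat, x, t_a, t_b):
    """Average over the covariate sample of mu_hat(x, t_a) - mu_hat(x, t_b)."""
    n = x.shape[0]
    ta = np.tile(t_a, (n, 1))
    tb = np.tile(t_b, (n, 1))
    return float(np.mean(mu_hat(x, ta) - mu_hat(x, tb)))

def ase(mu_hat, x, k, n_treat):
    """Assumed single effect of treatment k: k alone versus no treatment."""
    return mean_effect(mu_hat, x, pattern(n_treat, [k]), pattern(n_treat, []))

def aie(mu_hat, x, S, n_treat):
    """Assumed interaction effect of set S: joint effect minus the sum of single effects."""
    joint = mean_effect(mu_hat, x, pattern(n_treat, S), pattern(n_treat, []))
    return joint - sum(ase(mu_hat, x, k, n_treat) for k in S)

# Toy check with a known outcome function (single effects 2 and 3, interaction 1.5).
toy_mu = lambda x, t: x[:, 0] + 2.0 * t[:, 0] + 3.0 * t[:, 1] + 1.5 * t[:, 0] * t[:, 1]
toy_x = np.random.default_rng(0).normal(size=(100, 2))
print(ase(toy_mu, toy_x, 0, 2), aie(toy_mu, toy_x, (0, 1), 2))  # approximately 2.0 and 1.5
```

For a pair $S = \{j, k\}$ this definition reduces to the familiar contrast $\mu(\boldsymbol{x}, \boldsymbol{e}_j + \boldsymbol{e}_k) - \mu(\boldsymbol{x}, \boldsymbol{e}_j) - \mu(\boldsymbol{x}, \boldsymbol{e}_k) + \mu(\boldsymbol{x}, \boldsymbol{0})$ averaged over $\boldsymbol{x}$; for larger sets other definitions (for example, full inclusion-exclusion) are possible.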

Figures (5)

  • Figure 1: The architecture of CISI-Net consists of three components: the representation learning network (yellow), the task embedding network (green), and the outcome prediction network (blue). The latent representation $\Phi(\boldsymbol{x})$ is concatenated with the task embedding vector $t_w(\boldsymbol{t})$ to predict the outcome $y$. The model is trained with two loss terms: the prediction loss $L_y$ and the balancing penalty $L_\Phi$ (red).
  • Figure 2: Relationship between treatment vector similarity and learned task embedding vector similarity in simulation dataset 1. For two treatment vectors $\boldsymbol{t}_1$ and $\boldsymbol{t}_2$, their corresponding task embedding vectors are denoted as $t_w(\boldsymbol{t}_1)$ and $t_w(\boldsymbol{t}_2)$. The x-axis shows the Jaccard similarity between $\boldsymbol{t}_1$ and $\boldsymbol{t}_2$, and the y-axis shows the cosine similarity between $t_w(\boldsymbol{t}_1)$ and $t_w(\boldsymbol{t}_2)$. Box plots summarize the distribution of cosine similarities for each Jaccard similarity value.
  • Figure 3: $\epsilon_{\mathrm{ASE}}$ and $\epsilon_{\mathrm{AIE}}$ for different values of the balancing penalty coefficient $\alpha$ in CISI-Net.
  • Figure 4: Estimation errors of CISI-Net under varying sample sizes on simulation dataset 1. The left panel shows $\epsilon_{\mathrm{ASE}}$, and the right panel shows $\epsilon_{\mathrm{AIE}}$. Each line represents the mean error computed over 100 random seeds.
  • Figure 5: Estimated single and interaction treatment effects obtained from two real-world marketing promotion datasets. The left panel corresponds to dataset 1 (CP1-CP3), and the right panel corresponds to dataset 2 (CP4 and CP5). Here, $\tau_{\mathrm{ASE}}^{(d)}(k)$ and $\tau_{\mathrm{AIE}}^{(d)}(S)$ denote the estimated single effect of treatment $k$ and the estimated interaction effect of treatment combination $S$ in dataset $d$. The outcome is standardized prior to estimation.
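The data flow in Figure 1 can be sketched as a forward pass: covariates go through the representation network to give $\Phi(\boldsymbol{x})$, the treatment vector goes through the task embedding network to give $t_w(\boldsymbol{t})$, and the concatenation feeds the outcome prediction network; training minimizes $L_y + \alpha L_\Phi$. The NumPy sketch below uses random, untrained weights. Everything concrete in it is an assumption for illustration: the class and function names (`CISINetSketch`, `cisi_loss`, `mlp_init`, `mlp_apply`), the layer widths, ReLU activations, feeding the binary treatment vector into a small MLP as $t_w$, mean squared error for $L_y$, and a linear-MMD stand-in for the IPM penalty (squared distance between each pattern's mean representation and the pooled mean). Actual training would need gradients from an autodiff framework, which is omitted here.

```python
import numpy as np

def mlp_init(sizes, rng):
    """Random (weight, bias) pairs for a small fully connected network."""
    return [(rng.normal(0.0, 1.0 / np.sqrt(m), (m, n)), rng.normal(0.0, 0.1, n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mlp_apply(params, h):
    """Forward pass with ReLU on hidden layers and a linear output layer."""
    for i, (w, b) in enumerate(params):
        h = h @ w + b
        if i < len(params) - 1:
            h = np.maximum(h, 0.0)
    return h

class CISINetSketch:
    """Hypothetical forward pass y_hat = g([Phi(x), t_w(t)]); weights are random, untrained."""

    def __init__(self, n_cov, n_treat, d_rep=8, d_task=4, seed=0):
        rng = np.random.default_rng(seed)
        self.phi = mlp_init([n_cov, 16, d_rep], rng)         # representation network Phi
        self.task = mlp_init([n_treat, 8, d_task], rng)      # task embedding network t_w
        self.head = mlp_init([d_rep + d_task, 16, 1], rng)   # outcome prediction network

    def represent(self, x):
        return mlp_apply(self.phi, x)

    def embed(self, t):
        return mlp_apply(self.task, t)

    def predict(self, x, t):
        # Concatenate Phi(x) with t_w(t), as in Figure 1, then map to a scalar outcome.
        z = np.concatenate([self.represent(x), self.embed(t)], axis=1)
        return mlp_apply(self.head, z)[:, 0]

def cisi_loss(model, x, t, y, alpha=1.0):
    """Total loss L_y + alpha * L_Phi (MSE plus a linear-MMD stand-in for the IPM)."""
    l_y = np.mean((y - model.predict(x, t)) ** 2)
    phi = model.represent(x)
    pooled_mean = phi.mean(axis=0)
    # One group per observed treatment pattern (distinct row of t).
    groups = [np.all(t == p, axis=1) for p in np.unique(t, axis=0)]
    l_phi = np.mean([np.sum((phi[g].mean(axis=0) - pooled_mean) ** 2) for g in groups])
    return l_y + alpha * l_phi, l_y, l_phi
```

Because every treatment vector passes through the same embedding network, patterns that share active treatments share parameters, which is the mechanism the abstract credits for parameter sharing; setting `alpha=0` recovers a plain outcome regression without balancing.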

Theorems & Definitions (4)

  • Proposition 1
  • Corollary 1
  • Proof
  • Proof