Multiple Treatments Causal Effects Estimation with Task Embeddings and Balanced Representation Learning
Yuki Murakami, Takumi Hattori, Kohsuke Kubota
TL;DR
CISI-Net addresses the challenge of estimating single and interaction causal effects under multiple treatments by coupling a task embedding network that captures treatment similarity with a balanced representation learning framework to reduce selection bias. The approach enables parameter sharing across related treatment patterns and learns representations that are aligned across treatment patterns via an IPM-based penalty, improving counterfactual predictions for all treatment vectors. Across extensive simulations and real-world marketing data, CISI-Net outperforms baselines, including when latent covariates are present or interaction effects are absent, demonstrating robust applicability to complex decision-making problems. The work also outlines future directions for adaptive balancing, sample-efficient strategies, and doubly robust extensions to further enhance causal inference in multi-treatment settings.
Abstract
The simultaneous application of multiple treatments is increasingly common in fields such as healthcare and marketing. In such scenarios, it is important to estimate both the single treatment effects and the interaction treatment effects that arise from treatment combinations. Previous studies have proposed using independent outcome networks with subnetworks for interactions, or combining task embedding networks that capture treatment similarity with variational autoencoders. However, these methods either lack parameter sharing among related treatments or estimate unnecessary latent variables, which reduces the accuracy of causal effect estimation. To address these issues, we propose a novel deep learning framework that incorporates a task embedding network and a representation learning network with a balancing penalty. The task embedding network enables parameter sharing across related treatment patterns because it encodes elements common to single effects and contributions specific to interaction effects. The representation learning network with the balancing penalty learns representations nonparametrically from observed covariates while reducing the distance between representation distributions across different treatment patterns. This process mitigates selection bias and avoids model misspecification. Simulation studies demonstrate that the proposed method outperforms existing baselines, and application to real-world marketing datasets confirms the practical implications and utility of our framework.
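To make the balancing idea concrete, the following is a minimal sketch of one possible penalty of this kind, not the authors' implementation. It assumes a linear maximum mean discrepancy (the squared distance between mean representations) as the integral probability metric, and it compares each observed treatment pattern against the full sample. The function name `linear_mmd_penalty` and its arguments are hypothetical and chosen only for illustration.

```python
import numpy as np


def linear_mmd_penalty(rep, patterns):
    """Sum, over observed treatment patterns, of the squared distance between
    the group's mean representation and the full-sample mean representation.

    Assumption: a linear MMD stands in for the IPM; the paper may use another.
    rep:      (n_samples, rep_dim) learned representations
    patterns: (n_samples, n_treatments) binary treatment vectors
    """
    rep = np.asarray(rep, dtype=float)
    patterns = np.asarray(patterns)
    overall_mean = rep.mean(axis=0)

    # Map each sample to the index of its distinct treatment vector.
    _, group_idx = np.unique(patterns, axis=0, return_inverse=True)
    group_idx = group_idx.ravel()

    penalty = 0.0
    for g in np.unique(group_idx):
        group_mean = rep[group_idx == g].mean(axis=0)
        penalty += float(np.sum((group_mean - overall_mean) ** 2))
    return penalty
```

In a training loop, a term like this would typically be added to the outcome prediction loss with a weighting coefficient, so that representations whose distribution differs across treatment patterns are penalized. The penalty is zero when every treatment pattern has the same mean representation.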
