Counterfactual Generative Models for Time-Varying Treatments
Shenghao Wu, Wenbin Zhou, Minshuo Chen, Shixiang Zhu
TL;DR
This work tackles the challenge of estimating high-dimensional counterfactual outcomes under time-varying treatments by introducing a conditional generative framework that bypasses explicit density estimation. The authors incorporate inverse probability of treatment weighting (IPTW) into marginal structural generative models. This lets them train flexible conditional generators (guided diffusion models and CVAEs) to sample from a proxy conditional distribution that approximates the true counterfactual distribution $f_{\overline{a}}$ for a treatment history $\overline{a}$. On fully synthetic, semi-synthetic, and real COVID-19 data, the framework outperforms several baselines. Its advantage is clearest in capturing distributional features beyond the mean and in high-dimensional settings (e.g., TV-MNIST, with $m=784$). Because it produces samples rather than point estimates, the method supports uncertainty quantification and policy analysis. It reveals multi-modal and region-specific counterfactual outcomes, with practical implications for public health decisions under time-varying interventions. Future work includes extending the framework to continuous treatments, addressing potential violations of positivity or unmeasured confounding, and incorporating more advanced generative models to further improve sample fidelity.
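The IPTW step can be made concrete with the textbook marginal-structural-model weight for a binary time-varying treatment. For one unit with history $\overline{a} = (a_1, \dots, a_d)$, it is $w = \prod_{t=1}^{d} 1 / P(A_t = a_t \mid \overline{A}_{t-1}, \overline{X}_t)$. The sketch below assumes the propensities are already known or estimated, and it adds the common optional stabilization (numerator $P(A_t = a_t \mid \overline{A}_{t-1})$). The paper's exact weighting, propensity model, and any truncation are not stated here and may differ. The helper `iptw_weight` and its parameters are hypothetical.

```python
from typing import Optional, Sequence


def iptw_weight(
    treatments: Sequence[int],
    propensities: Sequence[float],
    marginals: Optional[Sequence[float]] = None,
) -> float:
    """IPTW weight for one unit's binary treatment history (hypothetical helper).

    treatments[t]:   observed treatment a_t in {0, 1}
    propensities[t]: P(A_t = 1 | past treatments and covariates), assumed given
    marginals[t]:    optional P(A_t = 1 | past treatments) for stabilized weights
    """
    if len(treatments) != len(propensities):
        raise ValueError("treatments and propensities must have equal length")
    if marginals is not None and len(marginals) != len(treatments):
        raise ValueError("marginals must match the treatment history length")
    weight = 1.0
    for t, (a_t, p_t) in enumerate(zip(treatments, propensities)):
        # Positivity: the received treatment must have nonzero probability.
        if not 0.0 < p_t < 1.0:
            raise ValueError(f"positivity violated at time step {t}")
        # Probability of the treatment actually received at step t.
        denom = p_t if a_t == 1 else 1.0 - p_t
        numer = 1.0
        if marginals is not None:
            numer = marginals[t] if a_t == 1 else 1.0 - marginals[t]
        weight *= numer / denom
    return weight


# Unstabilized: 1/0.5 * 1/0.2 * 1/0.25, approximately 40.
print(iptw_weight([1, 0, 1], [0.5, 0.8, 0.25]))
# Stabilized with marginals of 0.5 at every step: approximately 40 * 0.125 = 5.
print(iptw_weight([1, 0, 1], [0.5, 0.8, 0.25], [0.5, 0.5, 0.5]))
```

Because the unstabilized weight is a product over all time steps, it can become very large as the horizon grows. Stabilization is a standard way to reduce that variance, and it also shows why the positivity assumption mentioned above matters in practice.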
Abstract
Estimating the counterfactual outcome of a treatment is essential for decision-making in public health and clinical science, among other fields. Treatments are often administered sequentially over time, so the number of possible treatment histories, and hence of counterfactual outcomes, grows exponentially with the number of time steps. Moreover, in modern applications the outcomes are high-dimensional, and conventional average treatment effect estimation fails to capture disparities among individuals. To tackle these challenges, we propose a novel conditional generative framework that produces counterfactual samples under time-varying treatments without requiring explicit density estimation. Our method addresses the mismatch between the observed and counterfactual distributions through a loss function based on inverse probability re-weighting. It also integrates with state-of-the-art conditional generative models such as guided diffusion models and conditional variational autoencoders. We thoroughly evaluate our method on both synthetic and real-world data. Our results demonstrate that it generates high-quality counterfactual samples and outperforms state-of-the-art baselines.
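The following toy sketch shows why a re-weighted loss targets the counterfactual distribution rather than the observed conditional one. It rests on strong simplifying assumptions: two binary treatment steps, one binary baseline confounder, known propensities, and a one-dimensional Gaussian standing in for the generator for a single target history. Minimizing the weighted Gaussian negative log-likelihood has a closed-form solution (the weighted mean and variance). A guided diffusion model or CVAE would instead minimize an analogous weighted loss by stochastic gradient descent. The data-generating process and the helpers `simulate` and `fit_gaussian` are hypothetical, not the authors' code.

```python
import math
import random


def simulate(n, seed=0):
    """Hypothetical confounded process: baseline x drives both treatments and y."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = 1 if rng.random() < 0.5 else 0
        p1 = 0.8 if x else 0.2                 # P(A_1 = 1 | x)
        a1 = 1 if rng.random() < p1 else 0
        p2 = 0.7 if (x or a1) else 0.3         # P(A_2 = 1 | x, a_1)
        a2 = 1 if rng.random() < p2 else 0
        y = 2.0 * x + a1 + a2 + rng.gauss(0.0, 0.5)
        # IPTW weight: inverse probability of the observed treatment history.
        w = 1.0 / ((p1 if a1 else 1.0 - p1) * (p2 if a2 else 1.0 - p2))
        rows.append((a1, a2, y, w))
    return rows


def fit_gaussian(ys, weights):
    """Closed-form minimizer of sum_i w_i * GaussianNLL(y_i; mean, std)."""
    total = sum(weights)
    mean = sum(w * y for y, w in zip(ys, weights)) / total
    var = sum(w * (y - mean) ** 2 for y, w in zip(ys, weights)) / total
    return mean, math.sqrt(var)


rows = simulate(20000)
# Keep only units whose observed history matches the target history (1, 1).
matched = [(y, w) for a1, a2, y, w in rows if (a1, a2) == (1, 1)]
ys = [y for y, _ in matched]
naive = fit_gaussian(ys, [1.0] * len(ys))             # fits the observed conditional
weighted = fit_gaussian(ys, [w for _, w in matched])  # IPTW proxy for f_{(1,1)}
print("naive mean/std:    %.2f / %.2f" % naive)
print("weighted mean/std: %.2f / %.2f" % weighted)
```

In this toy process the counterfactual outcome under $\overline{a} = (1, 1)$ is $2X + 2 + \varepsilon$ with $X \sim \text{Bernoulli}(0.5)$, so its mean is 3.0 and its standard deviation is $\sqrt{1.25} \approx 1.12$. Units actually observed with history $(1, 1)$ have $X = 1$ with probability 0.8, so the unweighted fit centers near 3.6. The IPTW fit should land close to the counterfactual values.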
