Counterfactual Generative Modeling with Variational Causal Inference
Yulun Wu, Louie McConnell, Claudia Iriondo
TL;DR
This paper tackles the challenge of predicting high-dimensional counterfactual outcomes under interventions when covariates are limited. It introduces Variational Causal Inference (VCI), a framework that derives an ELBO for the individual-level counterfactual likelihood $p(Y' \mid Y, X, T, T')$ using a latent exogenous noise vector $Z$ and twin-like causal structures, which enables end-to-end counterfactual supervision without counterfactual samples. A key contribution is the latent divergence term that encourages disentanglement between $Z$ and the treatment $T$, yielding more faithful counterfactual generations and more reliable identification of causal effects. The authors also develop a robust marginal estimator for causal parameters and show that VCI outperforms state-of-the-art models on both vector-valued (single-cell perturbation) and image-valued (Morpho-MNIST, CelebA-HQ) tasks, with ablations that highlight the roles of counterfactual supervision and latent regularization. Together, these advances support high-fidelity counterfactual inference in domains with rich outcomes and limited covariates, with practical applications in biology and computer vision.
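To make the idea of a latent divergence term concrete, here is a minimal sketch, assuming (as is common in variational autoencoders) that the encoder outputs a diagonal Gaussian over $Z$. If $Z$ carries no information about the treatment, the posterior inferred from the factual pair $(Y, T)$ and the one inferred from a counterfactual pair $(Y', T')$ for the same individual should coincide, so a KL divergence between them penalizes treatment leakage into $Z$. The function name `kl_diag_gaussians`, the pairing of posteriors, and the example numbers are illustrative assumptions, not the paper's exact term or implementation.

```python
import math
from typing import Sequence


def kl_diag_gaussians(
    mu_q: Sequence[float],
    logvar_q: Sequence[float],
    mu_p: Sequence[float],
    logvar_p: Sequence[float],
) -> float:
    """Closed-form KL(q || p) between two diagonal Gaussians.

    Each distribution is given by per-dimension means and log-variances.
    """
    if not (len(mu_q) == len(logvar_q) == len(mu_p) == len(logvar_p)):
        raise ValueError("all parameter vectors must have the same length")
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        # Per-dimension term: log(var_p / var_q) + (var_q + (mu_q - mu_p)^2) / var_p - 1
        total += lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0
    return 0.5 * total


# Hypothetical encoder outputs (means, log-variances) for one individual:
# the posterior from the factual (Y, T) and from a counterfactual (Y', T').
z_factual = ([0.0, 1.0], [0.0, 0.0])
z_counterfactual = ([1.0, 1.0], [0.0, 0.0])
print(kl_diag_gaussians(*z_factual, *z_counterfactual))  # 0.5
```

The divergence is zero only when the two posteriors match, so minimizing it pushes the abducted noise to be the same regardless of which treatment accompanies the outcome.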
Abstract
Estimating an individual's counterfactual outcomes under interventions is a challenging task for traditional causal inference and supervised learning approaches when the outcome is high-dimensional (e.g., gene expression profiles or facial images) and covariates are relatively limited. In this case, to predict one's outcomes under counterfactual treatments, it is crucial to leverage the individual information contained in the observed outcome in addition to the covariates. Prior works using variational inference in counterfactual generative modeling have focused on neural adaptations and model variants within the conditional variational autoencoder formulation, which we argue is fundamentally ill-suited to the notion of counterfactuals in causal inference. In this work, we present a novel variational Bayesian causal inference framework and its theoretical foundations to properly handle counterfactual generative modeling tasks. The framework allows us to conduct counterfactual supervision end-to-end during training without any counterfactual samples, and it encourages disentangled exogenous noise abduction, which aids the correct identification of causal effects in counterfactual generations. In experiments, we demonstrate the advantage of our framework over state-of-the-art models in counterfactual generative modeling on multiple benchmarks.
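Why the observed outcome matters can be seen in a toy additive-noise structural causal model, using the classic abduction, action, prediction recipe for counterfactuals. The sketch below is a hypothetical illustration, not VCI's model: the mechanism `f`, its coefficients, and the additive form $Y = f(X, T) + Z$ are assumptions chosen so that the exogenous noise $Z$ can be recovered exactly from $(Y, X, T)$. In VCI the analogous abduction step is performed approximately by a learned encoder over a high-dimensional outcome.

```python
def f(x: float, t: int) -> float:
    """Hypothetical structural mechanism: covariate baseline plus a treatment effect."""
    return 2.0 * x + 3.0 * t


def counterfactual(y: float, x: float, t: int, t_prime: int) -> float:
    """Counterfactual outcome under t_prime for an individual observed as (y, x, t)."""
    # Abduction: recover the individual's exogenous noise from the observed outcome.
    z = y - f(x, t)
    # Action + prediction: rerun the mechanism under the new treatment with z held fixed.
    return f(x, t_prime) + z


def covariate_only_prediction(x: float, t_prime: int) -> float:
    """Ignores the observed outcome, so it can only return f(x, t'),
    the population-level mean if E[Z] = 0."""
    return f(x, t_prime)


y_obs = 7.5  # observed with x = 1.0, t = 0, so the abducted noise is z = 5.5
print(counterfactual(y_obs, x=1.0, t=0, t_prime=1))  # 10.5
print(covariate_only_prediction(1.0, t_prime=1))      # 5.0
```

Both predictions use the same covariate, yet only the first reflects this individual's idiosyncratic noise; with limited covariates, that gap is exactly the information the observed outcome supplies.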
