Counterfactual Generative Modeling with Variational Causal Inference

Yulun Wu, Louie McConnell, Claudia Iriondo

TL;DR

This paper tackles the challenge of predicting high-dimensional counterfactual outcomes under interventions when covariates are limited. It introduces Variational Causal Inference (VCI), a framework that derives an ELBO for the individual-level counterfactual likelihood $p(Y'|Y,X,T,T')$ using a latent exogenous noise vector $Z$ and twin-like causal structures to enable end-to-end counterfactual supervision without counterfactual samples. A key contribution is the latent divergence term that enforces disentanglement between $Z$ and the treatment $T$, yielding more faithful counterfactual generations and improved identifiability. The authors also develop a robust marginal estimator for causal parameters and demonstrate superior performance of VCI over state-of-the-art models on both vector-valued (single-cell perturbations) and image-valued (Morpho-MNIST, CelebA-HQ) tasks, including ablations that highlight the role of counterfactual supervision and latent regularization. Together, these advances enable reliable, high-fidelity counterfactual inference in domains with rich outcomes and limited covariates, with practical impact in biology and computer vision.

Abstract

Estimating an individual's counterfactual outcomes under interventions is a challenging task for traditional causal inference and supervised learning approaches when the outcome is high-dimensional (e.g., gene expressions, facial images) and covariates are relatively limited. In this case, to predict one's outcomes under counterfactual treatments, it is crucial to leverage individual information contained in the observed outcome in addition to the covariates. Prior works using variational inference in counterfactual generative modeling have focused on neural adaptations and model variants within the conditional variational autoencoder formulation, which we argue is fundamentally ill-suited to the notion of counterfactual in causal inference. In this work, we present a novel variational Bayesian causal inference framework and its theoretical backing to properly handle counterfactual generative modeling tasks, through which we are able to conduct counterfactual supervision end-to-end during training without any counterfactual samples, and encourage disentangled exogenous noise abduction that aids the correct identification of causal effects in counterfactual generations. In experiments, we demonstrate the advantage of our framework compared to state-of-the-art models in counterfactual generative modeling on multiple benchmarks.

Paper Structure

This paper contains 62 sections, 9 theorems, 56 equations, 13 figures, 8 tables.

Key Result

Theorem 1

Suppose a collection of random variables $W$ follows a causal structure defined by the Bayesian network in Figure 1. Then $J(D) = \log [ p (Y' | Y, X, T, T') ]$ admits a variational lower bound expressed in terms of the divergence $D [ p \parallel q ] = \log p - \log q$.
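The divergence in Theorem 1 is defined pointwise as a log-density ratio. A standard identity is that its expectation under $p$ equals the Kullback-Leibler divergence $\mathrm{KL}(p \parallel q)$. The sketch below checks this numerically for two univariate Gaussians. The Gaussian choice, the parameter values, and all function names are assumptions made for illustration and are not taken from the paper.

```python
import numpy as np

def log_normal_pdf(z, mean, std):
    """Log-density of a univariate Gaussian N(mean, std^2) evaluated at z."""
    return -0.5 * np.log(2.0 * np.pi) - np.log(std) - 0.5 * ((z - mean) / std) ** 2

def pointwise_divergence(z, p, q):
    """D[p || q] = log p(z) - log q(z), with p and q given as (mean, std) pairs."""
    return log_normal_pdf(z, *p) - log_normal_pdf(z, *q)

def gaussian_kl(p, q):
    """Closed-form KL(p || q) between two univariate Gaussians."""
    (m1, s1), (m2, s2) = p, q
    return np.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2.0 * s2 ** 2) - 0.5

# Illustrative (assumed) distributions: p = N(0, 1), q = N(1, 4).
p, q = (0.0, 1.0), (1.0, 2.0)
rng = np.random.default_rng(0)
z = rng.normal(p[0], p[1], size=200_000)      # samples drawn from p

mc_estimate = pointwise_divergence(z, p, q).mean()   # Monte Carlo E_p[log p - log q]
exact = gaussian_kl(p, q)
print(f"Monte Carlo: {mc_estimate:.3f}, closed form: {exact:.3f}")
```

Averaging the pointwise term over samples from $p$ is how such a divergence is typically estimated in practice when no closed form is available.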

Figures (13)

  • Figure 1: The causal diagram where variables are generated from the following SCM: $X=f_X(U_X)$; $T=f_T(X, U_T)$, $T'=f_T(X, U'_T)$ where $U_T$, $U'_T$ are i.i.d.; $Z=f_Z(X, U_Y)$; $Y=f_Y(Z, T, \epsilon_Y)$, $Y'=f_Y(Z, T', \epsilon'_Y)$ where $\epsilon_Y$, $\epsilon'_Y$ are i.i.d., independent of all $U$s, and are constants w.p. 1 if the consistency assumption holds. Endogenous variables follow the Bayesian network on the left under the ignorability assumption, with at most one additional edge permitted from either $Z$ or $T$ to $X$ if $U_Y$ or $U_T$ is dependent on $U_X$. We call $O=(X, T, Y)$ observed variables, $B=(X, T, T', Y)$ sample variables ($T'$ is sampled for model training and evaluation), $D=(X, T, T', Y, Y')$ full data variables and $W=(Z, X, T, T', Y, Y')$ all variables. White nodes are observed, light grey nodes are assigned during training and inference, and dark grey nodes are unobserved.
  • Figure 2: Model workflow of the Variational Causal Inference (VCI) framework. In a forward pass, the encoding model takes the outcome $y$ together with its treatments $t$ and covariates $x$ (if any) as inputs and produces the latent feature $z$. The pairs $(z, t)$ and $(z, t')$, where $t'$ denotes the counterfactual treatments, are passed separately into the decoding model to produce the reconstruction $y_{\theta, \phi}$ and the counterfactual construction $y'_{\theta, \phi}$, on which the reconstruction loss and the counterfactual supervision loss are evaluated. $y'_{\theta, \phi}$ is then passed back into the encoding model along with $t'$ and $x$ to produce a latent feature again, which the latent divergence term encourages to match the latent feature $z$ previously obtained from encoding the factual outcome $y$ and treatment $t$. The neural architectures of the latent recognition and outcome construction models for image generation tasks can be found in Appendix \ref{sec:architecture}.
  • Figure 3: Ablation Study: the error of counterfactual prediction across epochs during the training of HAE (VCI without counterfactual supervision and latent divergence), SAE (VCI without latent divergence) and VCI over five independent runs. Note that VCI with latent divergence but without counterfactual supervision does not make logical sense, but for the completeness of the ablation study, we present the results for this setting in Appendix \ref{experiment:complete-ablation}.
  • Figure 4: Illustration of the latent divergence term's long-term impact. Samples are drawn from the last epoch of SAE and VCI on $do$(intensity). The leftmost image of each set is the original image. More sampled results from VCI on $do$(thickness) and $do$(intensity) can be found in Figure \ref{fig:morpho-mnist-sample}. Results from intervening on the digit, i.e. $do$(digit), can be found in Figure \ref{fig:morpho-mnist-sample-digit}.
  • Figure 5: Results on the test set of CelebA-HQ.
  • ...and 8 more figures
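
The twin structure in the Figure 1 caption and the encode, decode, re-encode loop in the Figure 2 caption can be illustrated with a toy sketch. Everything concrete below is an assumption made for illustration: the scalar variables, the logistic treatment rule, the mechanism `f_y`, and the hand-written `encode` function, which simply inverts `f_y` in place of a learned encoder. It is not the authors' model. A squared gap stands in for the latent divergence term, and the counterfactual supervision loss is omitted because the captions do not specify its form.

```python
import numpy as np

def f_y(z, t, eps=0.0):
    """Hypothetical outcome mechanism Y = f_Y(Z, T, eps); eps is a constant (consistency)."""
    return z * (1.0 + t) + 0.5 * t + eps

def sample_scm(n, rng):
    """Draw (X, T, T', Z, Y, Y') following the structure in the Figure 1 caption."""
    u_x = rng.normal(size=n)
    u_t = rng.uniform(size=n)
    u_t_cf = rng.uniform(size=n)               # U'_T: an i.i.d. copy of U_T
    u_y = rng.normal(size=n)

    x = u_x                                    # X  = f_X(U_X)
    propensity = 1.0 / (1.0 + np.exp(-x))      # assumed logistic treatment rule
    t = (u_t < propensity).astype(float)       # T  = f_T(X, U_T)
    t_cf = (u_t_cf < propensity).astype(float) # T' = f_T(X, U'_T)
    z = x + u_y                                # Z  = f_Z(X, U_Y)
    y = f_y(z, t)                              # Y  = f_Y(Z, T, eps)
    y_cf = f_y(z, t_cf)                        # Y' = f_Y(Z, T', eps'), never observed
    return x, t, t_cf, z, y, y_cf

def encode(y, t, x):
    """Hypothetical oracle encoder: recovers z by inverting f_y (x is unused in this toy)."""
    return (y - 0.5 * t) / (1.0 + t)

def forward(y, t, t_cf, x):
    """One forward pass in the order described by the Figure 2 caption."""
    z = encode(y, t, x)                  # latent feature from the factual outcome
    y_rec = f_y(z, t)                    # reconstruction under the factual treatment
    y_cf_hat = f_y(z, t_cf)              # counterfactual construction under t'
    z_again = encode(y_cf_hat, t_cf, x)  # re-encode the counterfactual construction
    latent_gap = np.mean((z_again - z) ** 2)  # stand-in for the latent divergence
    return y_rec, y_cf_hat, latent_gap

rng = np.random.default_rng(0)
x, t, t_cf, z, y, y_cf = sample_scm(1000, rng)
y_rec, y_cf_hat, latent_gap = forward(y, t, t_cf, x)
print("max counterfactual error:", np.max(np.abs(y_cf_hat - y_cf)))
print("latent gap:", latent_gap)
```

Because the toy encoder recovers $z$ exactly, the counterfactual construction matches the held-out $Y'$ and the latent gap vanishes. A learned encoder has no such guarantee, which is the situation the latent divergence term is meant to address.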

Theorems & Definitions (23)

  • Theorem 1
  • Lemma 1
  • Proposition 1
  • Proposition 2
  • Corollary 1
  • Corollary 2
  • Proposition 3
  • Definition 1: Oracle Consistency
  • Definition 2: Oracle Restrictiveness
  • Definition 3: Disentanglement
  • ...and 13 more