Variational Causal Inference

Yulun Wu, Layne C. Price, Zichen Wang, Vassilis N. Ioannidis, Robert A. Barton, George Karypis

TL;DR

The paper addresses high-dimensional individualized counterfactual outcomes with limited covariates by introducing Variational Causal Inference (VCI), a semi-autoencoding, variational Bayesian framework. It jointly models the factual outcome via a latent representation $Z$ and the counterfactual via a covariate-specific distribution $p(Y'|X,T')$, deriving an ELBO-based objective that couples reconstruction and counterfactual likelihood while regularizing latent distributions. An efficient influence-function-based estimator for covariate-specific and marginal effects is provided, enabling asymptotically efficient estimation of $\boldsymbol{\Psi}(p) = \text{E}_p[Y'_{\text{do}(T'=a)}]$ and related quantities. Empirical evaluation on single-cell perturbation data demonstrates that VCI outperforms state-of-the-art baselines in out-of-distribution predictions and in robust marginal estimations, highlighting its practical impact for high-dimensional outcomes and limited covariates.
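
To make the idea of an influence-function-based estimator of $\text{E}_p[Y'_{\text{do}(T'=a)}]$ concrete, here is a minimal sketch of the classic augmented inverse-propensity-weighted (AIPW) form for a discrete treatment. It is a textbook construction and not necessarily the estimator derived in the paper. The function name and the inputs `propensity` and `outcome_pred` are hypothetical, standing in for nuisance estimates produced by models fitted elsewhere.

```python
import numpy as np


def aipw_marginal_mean(y, t, a, propensity, outcome_pred):
    """AIPW estimate of the marginal mean outcome under treatment a.

    y:            (n, d) observed (factual) outcomes
    t:            (n,)   observed treatment labels
    a:            treatment level of interest
    propensity:   (n,)   assumed estimate of p(T = a | X_i)
    outcome_pred: (n, d) assumed estimate of E[Y | X_i, T = a]
    """
    y = np.asarray(y, dtype=float)
    outcome_pred = np.asarray(outcome_pred, dtype=float)
    # 1 for units that factually received treatment a, 0 otherwise.
    indicator = (np.asarray(t) == a).astype(float)
    weights = indicator / np.asarray(propensity, dtype=float)
    # Plug-in prediction plus an inverse-propensity-weighted residual,
    # which corrects the outcome model's bias on the treated units.
    correction = weights[:, None] * (y - outcome_pred)
    return (outcome_pred + correction).mean(axis=0)


# Tiny illustration: four units, two-dimensional outcome, treatment
# assigned with probability 0.5 (hypothetical numbers).
y_obs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
t_obs = np.array([0, 1, 0, 1])
estimate = aipw_marginal_mean(
    y_obs, t_obs, a=1,
    propensity=np.full(4, 0.5),
    outcome_pred=np.zeros((4, 2)),
)
print(estimate)
```

With an all-zero outcome model the estimate reduces to plain inverse-propensity weighting; a good outcome model shrinks the residuals and hence the variance of the correction term.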

Abstract

Estimating an individual's potential outcomes under counterfactual treatments is a challenging task for traditional causal inference and supervised learning approaches when the outcome is high-dimensional (e.g. gene expressions, impulse responses, human faces) and covariates are relatively limited. In this case, to construct one's outcome under a counterfactual treatment, it is crucial to leverage individual information contained in its observed factual outcome on top of the covariates. We propose a deep variational Bayesian framework that rigorously integrates two main sources of information for outcome construction under a counterfactual treatment: one source is the individual features embedded in the high-dimensional factual outcome; the other source is the response distribution of similar subjects (subjects with the same covariates) that factually received this treatment of interest.

Paper Structure (20 sections, 2 theorems, 9 equations, 3 figures, 2 tables)

Key Result

Theorem 1

Suppose random vector $W=(X, Z, T, T', Y, Y')$ follows the causal structure defined by the Bayesian network in Figure 1. Then $\log \left[ p (Y' | Y, X, T, T') \right]$ admits a variational lower bound stated in terms of the pointwise log-ratio $D [ p \parallel q ] = \log p - \log q$.
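
As a short aside on this notation (a standard identity, not a step quoted from the paper): $D$ is a pointwise quantity that can take either sign, and averaging it under $p$ recovers the Kullback-Leibler divergence, which is non-negative.

$$
\mathbb{E}_{p}\big[ D [ p \parallel q ] \big] = \mathbb{E}_{p}\left[ \log p - \log q \right] = \mathrm{KL} [ p \parallel q ] \geq 0
$$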

Figures (3)

  • Figure 1: The causal relation diagram. Each individual has a covariate-dependent feature state $Z$. Treatment $T$ (or counterfactual treatment $T'$) along with $Z$ determines outcome $Y$ (or counterfactual outcome $Y'$). In the causal diagram, white nodes are observed and dark grey nodes are unobserved; dashed edges exist if the data were not generated from a completely randomized trial.
  • Figure 2: Dependency structure of the encoder and decoder. White nodes are observed, the light grey node is assigned (sampled), and dark grey nodes are inferred; the dashed edge is optional. Note that the decoder estimates the conditional outcome distribution of $Y'$, so $T'$ need not be sampled according to the true distribution $p(T' | X)$ during optimization.
  • Figure 3: GANITE's counterfactual generator. $t_{cf}$ is a random sample of $\bm{t}$, passed into the generator as a part of the input $(x, t_{cf}, y_f)$, separately from input $(x, t_f, y_f)$ of the factual generation.
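
The causal structure described in the Figure 1 caption can be illustrated with a small simulation. This is a sketch under strong assumptions: the linear mechanisms, the dimensions, and every name (`simulate_scm`, `z_means`, `loading`, `t_effects`) are hypothetical and chosen only for illustration. Treatments are drawn independently of $X$, which corresponds to the completely randomized case in which the dashed edges are absent.

```python
import numpy as np


def simulate_scm(n, n_covariate_levels=3, n_treatments=2,
                 latent_dim=4, outcome_dim=50, seed=0):
    """Draw n individuals from a toy model with X -> Z, (Z, T) -> Y, (Z, T') -> Y'."""
    rng = np.random.default_rng(seed)
    # Hypothetical mechanism parameters (not taken from the paper).
    z_means = rng.normal(size=(n_covariate_levels, latent_dim))
    loading = rng.normal(size=(latent_dim, outcome_dim))
    t_effects = rng.normal(size=(n_treatments, outcome_dim))

    x = rng.integers(n_covariate_levels, size=n)        # discrete covariate
    z = z_means[x] + rng.normal(size=(n, latent_dim))   # covariate-dependent feature state
    t = rng.integers(n_treatments, size=n)              # factual treatment, randomized
    t_cf = rng.integers(n_treatments, size=n)           # counterfactual treatment
    y = z @ loading + t_effects[t]                      # outcome determined by (Z, T)
    y_cf = z @ loading + t_effects[t_cf]                # same Z, different treatment
    return {"x": x, "z": z, "t": t, "t_cf": t_cf,
            "y": y, "y_cf": y_cf, "t_effects": t_effects}


data = simulate_scm(n=1000)
# Y and Y' share the individual's Z, so their difference depends only
# on the two treatments in this toy model.
gap = data["y_cf"] - data["y"]
expected = data["t_effects"][data["t_cf"]] - data["t_effects"][data["t"]]
print(np.allclose(gap, expected))
```

Because $Y$ and $Y'$ share the same $Z$, the factual outcome carries individual-level information about the counterfactual that the low-dimensional covariate $X$ alone does not.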

Theorems & Definitions (4)

  • Theorem 1
  • Theorem 2
  • proof
  • proof