
Conditional Generative Models are Sufficient to Sample from Any Causal Effect Estimand

Md Musfiqur Rahman, Matt Jordan, Murat Kocaoglu

TL;DR

This work shows how to sample from any identifiable interventional distribution given an arbitrary causal graph through a sequence of push-forward computations of conditional generative models, such as diffusion models.

Abstract

Causal inference from observational data plays a critical role in many applications in trustworthy machine learning. While sound and complete algorithms exist to compute causal effects, many of them assume access to conditional likelihoods, which are difficult to estimate for high-dimensional (particularly image) data. Researchers have alleviated this issue by simulating causal relations with neural models. However, when the causal graph contains high-dimensional variables along with some unobserved confounders, no existing work can effectively sample from the conditional or unconditional interventional distributions. In this work, we show how to sample from any identifiable interventional distribution given an arbitrary causal graph through a sequence of push-forward computations of conditional generative models, such as diffusion models. Our proposed algorithm follows the recursive steps of the existing likelihood-based identification algorithms to train a set of feed-forward models, and connects them in a specific way to sample from the desired distribution. We conduct experiments on a Colored MNIST dataset having both the treatment ($X$) and the target variables ($Y$) as images and sample from $P(y|do(x))$. Our algorithm also enables us to conduct a causal analysis to evaluate spurious correlations among input features of generative models pre-trained on the CelebA dataset. Finally, we generate high-dimensional interventional samples from the MIMIC-CXR dataset involving text and image variables.

Paper Structure

This paper contains 43 sections, 20 theorems, 16 equations, 24 figures, 4 tables, 7 algorithms.

Key Result

Theorem 3.2

Under the assumptions that i) the SCM is semi-Markovian, ii) we have access to the ADMG, iii) $P(\mathbf{V})$ is strictly positive, and iv) the trained generative models sample from the correct distributions, ID-GEN and IDC-GEN are sound and complete for sampling from any identifiable $P_x(y)$ and $P_{x}(y|z)$.

Figures (24)

  • Figure 1: (Top: x-ray to report generation task) (a) $\text{do}(X=x)$ removes the $X\leftrightarrow R$ bias and makes the generation of $R$ domain invariant. $P(r|\text{do}(x))$ is factorized into c-factors and (b) conditional models ($\{M_{V_k}\}_{k}$) are trained for each factor (shown as boxes). (c) The intervened value $X=x$ is propagated through the merged network, and samples from $P(r|\text{do}(x))$ are generated.
  • Figure 2: $\leftrightarrow$ denotes unobserved confounding. Left (blue): samples from $P_{x,w_2}(w_1,y) = P(w_1|x)\,P(y|x,w_1,w_2)$. Right (blue): samples from $P_{x,w_1}(w_2) = \sum_{x'} P(x')\,P(w_2|x',w_1)$. The joint network samples from $P_x(y)$.
  • Figure 3: (Left: top-down) $P_{w_1}(y)$ is factorized into $P_{w_1, x, y}(w_2)$, $P_{w_1, w_2, y}(x)$ and $P_{w_1, w_2, x}(y)$ (Step 4). Steps 7, 2, and 6 are shown for $P_{w_1, w_2, x}(y)$ only. (Right: bottom-up) We combine the sampling networks of each c-factor. For any $\text{do}(W_1=w_1)$, we use $\mathcal{H}$ to get samples from $P_{w_1}(y)$.
  • Figure 4: (Left) Causal graph with color and thickness as unobserved variables. (Center) FID scores (lower is better) of each algorithm and images generated by them. (Right) Likelihood calculated from the $P_{x}(y)$ images generated by each algorithm. We closely reflect the true $P_{x}(y)$ with low TVD.
  • Figure 5: i) Graph and sampling network for $P_{Male}(I_2)$. ii) For both causal and non-causal attributes, EGSDE shows high correlation.
  • ...and 19 more figures
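
The factorization in the Figure 2 caption can be read as a recipe for ancestral sampling. The sketch below is only an illustration of that reading, not the authors' implementation: it assumes binary variables and replaces the trained conditional generative models with hypothetical hand-written probability tables (`p_w1`, `p_w2`, `p_y`, `P_XPRIME_1` are all made up for this example). It draws $w_1 \sim P(w_1|x)$, draws a fresh $x' \sim P(x')$ that is independent of the intervened $x$, draws $w_2 \sim P(w_2|x',w_1)$, and finally draws $y \sim P(y|x,w_1,w_2)$. Assuming $P_x(y)$ is obtained by multiplying the two factors in the caption and summing out $w_1$ and $w_2$, the Monte Carlo estimate should agree with the exact enumeration.

```python
import itertools
import random

# Hypothetical conditional probability tables over binary variables.
# They stand in for trained conditional generative models.
P_XPRIME_1 = 0.3  # P(X' = 1), the observational marginal of X


def p_w1(x):
    """P(W1 = 1 | x)."""
    return 0.2 + 0.6 * x


def p_w2(x_prime, w1):
    """P(W2 = 1 | x', w1)."""
    return 0.1 + 0.5 * x_prime + 0.3 * w1


def p_y(x, w1, w2):
    """P(Y = 1 | x, w1, w2)."""
    return 0.05 + 0.3 * x + 0.25 * w1 + 0.35 * w2


def bernoulli(p, rng):
    """Draw 1 with probability p, else 0."""
    return 1 if rng.random() < p else 0


def sample_y(x, rng):
    """One pass through the merged sampling network for do(X = x)."""
    w1 = bernoulli(p_w1(x), rng)             # left factor: W1 ~ P(w1 | x)
    x_prime = bernoulli(P_XPRIME_1, rng)     # right factor: fresh X' ~ P(x')
    w2 = bernoulli(p_w2(x_prime, w1), rng)   # right factor: W2 ~ P(w2 | x', w1)
    return bernoulli(p_y(x, w1, w2), rng)    # left factor: Y ~ P(y | x, w1, w2)


def mass(p_one, value):
    """Probability of `value` under a Bernoulli with P(1) = p_one."""
    return p_one if value == 1 else 1.0 - p_one


def exact_p_y1(x):
    """Exact P_x(Y = 1) by enumerating w1, w2 and x'."""
    total = 0.0
    for w1, w2, x_prime in itertools.product([0, 1], repeat=3):
        total += (mass(p_w1(x), w1)
                  * mass(P_XPRIME_1, x_prime)
                  * mass(p_w2(x_prime, w1), w2)
                  * p_y(x, w1, w2))
    return total


def estimate_p_y1(x, n, seed=0):
    """Monte Carlo estimate of P_x(Y = 1) from n network samples."""
    rng = random.Random(seed)
    return sum(sample_y(x, rng) for _ in range(n)) / n
```

Calling `estimate_p_y1(1, 100_000)` and `exact_p_y1(1)` should give values that differ only by sampling noise. In the paper's setting each table would be a conditional generative model (for example a diffusion model) and the variables could be images or text, but the wiring between the samplers would follow the same pattern.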

Theorems & Definitions (41)

  • Definition 3.1: Sampling network, $\mathcal{H}$
  • Theorem 3.2
  • Example C.1: Cyclic dependency
  • Definition D.1
  • Definition D.2: Recursive call
  • Lemma D.7: c-component factorization (Tian and Pearl, 2002)
  • Definition D.8
  • Proposition D.9
  • proof
  • Proposition D.10
  • ...and 31 more