Modular Learning of Deep Causal Generative Models for High-dimensional Causal Inference

Md Musfiqur Rahman, Murat Kocaoglu

TL;DR

Modular-DCM is the first algorithm that, given the causal structure, uses adversarial training to learn the network weights and can use pre-trained models to provably sample from any identifiable causal query in the presence of latent confounders.

Abstract

Sound and complete algorithms have been proposed to compute identifiable causal queries using the causal structure and data. However, most of these algorithms assume accurate estimation of the data distribution, which is impractical for high-dimensional variables such as images. On the other hand, modern deep generative architectures can be trained to sample from high-dimensional distributions. However, training these networks is typically very costly. Thus, it is desirable to leverage pre-trained models to answer causal queries using such high-dimensional data. To address this, we propose modular training of deep causal generative models that not only makes learning more efficient but also allows us to utilize large, pre-trained conditional generative models. To the best of our knowledge, our algorithm, Modular-DCM, is the first algorithm that, given the causal structure, uses adversarial training to learn the network weights and can make use of pre-trained models to provably sample from any identifiable causal query in the presence of latent confounders. With extensive experiments on the Colored-MNIST dataset, we demonstrate that our algorithm outperforms the baselines. We also show our algorithm's convergence on the COVIDx dataset and its utility with a causal invariant prediction problem on CelebA-HQ.

Paper Structure

This paper contains 44 sections, 17 theorems, 59 equations, 24 figures, 1 table, and 10 algorithms.

Key Result

Theorem 4.2

Consider any SCM $\mathcal{M}=(G, \mathcal{N}, \mathcal{U}, \mathcal{F}, P(.) )$. A DCM $\mathbb{G}$ for $G$ entails the same identifiable interventional distributions as the SCM $\mathcal{M}$ if it entails the same observational distribution.
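One way to write the statement in symbols is sketched below. The notation is an assumption introduced here for illustration and is not taken from the paper: $\mathcal{V}$ denotes the observed variables, and $P_{\mathbb{G}}$ and $P_{\mathcal{M}}$ denote the distributions entailed by the DCM and the SCM, respectively.

$$
P_{\mathbb{G}}(\mathcal{V}) = P_{\mathcal{M}}(\mathcal{V})
\;\Longrightarrow\;
P_{\mathbb{G}}\big(Y \mid \mathrm{do}(X)\big) = P_{\mathcal{M}}\big(Y \mid \mathrm{do}(X)\big)
\quad \text{for every } P\big(Y \mid \mathrm{do}(X)\big) \text{ identifiable from } G \text{ and } P(\mathcal{V}).
$$

As a familiar instance of an identifiable query, consider a front-door graph $X \rightarrow Z \rightarrow Y$ with a latent confounder between $X$ and $Y$ (the variable names are illustrative). There the interventional distribution is given by the standard front-door formula:

$$
P\big(y \mid \mathrm{do}(x)\big) = \sum_{z} P(z \mid x) \sum_{x'} P(y \mid x', z)\, P(x').
$$

Under the theorem, a DCM that matches the observational distribution on such a graph would produce samples from this same quantity when the intervention is applied to the model.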

Figures (24)

  • Figure 1: Causal graph for the $\mathrm{XrayImg}$ example (top) and its deep causal generative model (bottom). For each variable, an NN ($\mathbb{G}_C,\mathbb{G}_X, \mathbb{G}_N$) is trained to mimic the true mechanism.
  • Figure 2: (Left) Modular training in 3 steps. (Right) Causal graphs and their h-graphs showing modularization of the training process.
  • Figure 3: For the frontdoor graph in Figure \ref{fig:frontdoor-graph}, NCM produces good images that are not consistent with $\text{do}(D)$. Modular-DCM without modular training (DCM-rep) produces consistent but low-quality images. Our modular approach (DCM) with training order $\{I\}\rightarrow \{D, A\}$ produces consistent, good images and converges faster (as shown in Figures \ref{fig:digit-benchmarks}, \ref{fig:tvd-frontdoor}, and \ref{fig:fid-frontdoor}). In Figure \ref{fig:fid-diamond}, we show our performance for the graph in Figure \ref{fig:diamond-graph}, which contains two image variables.
  • Figure 4: Modular DCM on specific and arbitrary graphs.
  • Figure 5: (Top-left) Invariant prediction causal graph. (Top-right) Images generated by InterFaceGAN from $P(I|Sex,Eyeglass=1)$. (Bottom-left) Joint distribution $P(Sex,Eyeglass)$. (Bottom-right) Eyeglass prediction accuracy of 3 classifiers in different sub-populations. The three classifiers are trained on the training dataset, the interventional dataset, and the augmented dataset (a combination of both). Note that the classifier trained on the augmented dataset has better accuracy in the $Sex=0,Eyeglass=1$ sub-population, which was our target.
  • ...and 19 more figures
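
The idea in the Figure 1 caption, one generator per variable that mimics its mechanism, can be illustrated with a minimal sketch. Everything specific here is an assumption made for illustration: the edges ($C \rightarrow X \rightarrow N$ with a latent confounder shared by $C$ and $N$), the binary variables, and the hand-written functions standing in for the trained networks $\mathbb{G}_C$, $\mathbb{G}_X$, $\mathbb{G}_N$. The function name `sample_dcm` is hypothetical and is not the paper's code.

```python
import numpy as np


def sample_dcm(n_samples, do_x=None, rng=None):
    """Draw samples from a toy causal generative model (illustrative only).

    Assumed graph: C -> X -> N, with a latent confounder between C and N.
    The confounder is modeled by feeding the same noise u_cn to both
    mechanisms. Each g_* below is a fixed function standing in for a
    trained generator network.
    """
    if rng is None:
        rng = np.random.default_rng()

    # Exogenous noise: one shared term for the confounded pair (C, N),
    # plus one independent term per variable.
    u_cn = rng.normal(size=n_samples)
    e_c = rng.normal(size=n_samples)
    e_x = rng.normal(size=n_samples)
    e_n = rng.normal(size=n_samples)

    # g_C(u_cn, e_c): mechanism for C.
    c = (u_cn + 0.5 * e_c > 0).astype(int)

    # g_X(c, e_x): mechanism for X, replaced by a constant under do(X = do_x).
    if do_x is None:
        x = (c + 0.3 * e_x > 0.5).astype(int)
    else:
        x = np.full(n_samples, do_x, dtype=int)

    # g_N(x, u_cn, e_n): mechanism for N, which still sees the shared noise.
    n_var = (x + u_cn + 0.5 * e_n > 1.0).astype(int)
    return c, x, n_var


# Observational samples versus samples under the intervention do(X = 1).
obs = sample_dcm(5000, rng=np.random.default_rng(0))
intv = sample_dcm(5000, do_x=1, rng=np.random.default_rng(0))
print("P(N=1), observational:", obs[2].mean())
print("P(N=1 | do(X=1)):     ", intv[2].mean())
```

In this sketch an intervention only swaps out the mechanism of the intervened variable and leaves the others, including the shared noise, untouched. In the paper the mechanisms are neural networks whose weights are learned with adversarial training; that training loop is not shown here.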

Theorems & Definitions (47)

  • Definition 3.1: Structural causal model (SCM) [pearl2009causality]
  • Definition 3.2: c-components
  • Definition 4.1: DCM
  • Theorem 4.2
  • Definition 4.3: $\mathcal{H}$-graph
  • Definition 4.4
  • Theorem 4.5
  • Definition 3.1: Identifiability [shpitser2007counterfactuals]
  • Definition 3.2: Causal Effects z-Identifiability
  • Theorem 3.3
  • ...and 37 more