Semi-Supervised Learning for Deep Causal Generative Models

Yasin Ibrahim, Hermione Warr, Konstantinos Kamnitsas

TL;DR

This work tackles counterfactual reasoning in medical imaging under missing labels by introducing a semi-supervised deep causal generative framework that merges a hierarchical VAE with predictive components for causal variables $y_C$ and $y_E$. It leverages ELBO-based losses for fully labelled, unlabelled, and partially labelled data, and adds counterfactual regularisation via do-interventions to enforce causal consistency, together with an invertible abduction-action-prediction scheme for counterfactual generation. The approach is demonstrated on Colour Morpho-MNIST and MIMIC-CXR, showing improved counterfactual accuracy, robustness to incomplete labels, and the ability to learn causal relationships under data scarcity, while analysing the independence of cause and mechanism through label availability. A limitation is the assumption of a known DAG; future work may address DAG misspecification and learning causal structure from data, with potential to augment underrepresented populations in clinical datasets.
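To make the abduction-action-prediction scheme mentioned above concrete, here is a minimal sketch on a toy structural causal model. The linear mechanism, its coefficient, and all function names are assumptions chosen for illustration; they are not the paper's learned networks. The variable names `t` (thickness) and `i` (intensity) only echo the Colour Morpho-MNIST variables.

```python
# Hypothetical toy mechanisms (illustrative only, not the paper's model):
#   t = u_t              thickness is exogenous
#   i = 2.0 * t + u_i    intensity depends on thickness plus noise

def f_i(t, u_i):
    """Assumed invertible mechanism for intensity."""
    return 2.0 * t + u_i

def abduct_u_i(t, i):
    """Invert the mechanism to recover the noise that explains (t, i)."""
    return i - 2.0 * t

def counterfactual_intensity(t_obs, i_obs, t_new):
    # 1. Abduction: infer the exogenous noise from the observation.
    u_i = abduct_u_i(t_obs, i_obs)
    # 2. Action: apply do(t = t_new), overriding the observed thickness.
    # 3. Prediction: push the same noise through the mechanism.
    return f_i(t_new, u_i)

cf = counterfactual_intensity(t_obs=1.0, i_obs=2.5, t_new=3.0)
print(cf)  # 6.5
```

Because the noise term is held fixed, intervening with the observed value itself returns the original observation, which is the consistency property an invertible scheme is meant to provide.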

Abstract

Developing models that are capable of answering questions of the form "How would x change if y had been z?" is fundamental to advancing medical image analysis. Training causal generative models that address such counterfactual questions, though, currently requires that all relevant variables have been observed and that the corresponding labels are available in the training data. However, clinical data may not have complete records for all patients, and state-of-the-art causal generative models are unable to take full advantage of this. We thus develop, for the first time, a semi-supervised deep causal generative model that exploits the causal relationships between variables to maximise the use of all available data. We explore this in the setting where each sample is either fully labelled or fully unlabelled, as well as the more clinically realistic case of having different labels missing for each sample. We leverage techniques from causal inference to infer missing values and subsequently generate realistic counterfactuals, even for samples with incomplete labels.
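The partially labelled setting described above can be sketched as follows: observed labels are used where available, predictions fill the gaps, and the supervised term is computed only over observed entries. The NaN convention for missing labels, the function names, and the squared-error term are assumptions for illustration; the squared error stands in for the paper's ELBO-based losses.

```python
import numpy as np

def fill_missing(labels, predictions):
    """Use observed labels where present, predicted values otherwise."""
    observed = ~np.isnan(labels)          # True where a label exists
    filled = np.where(observed, labels, predictions)
    return filled, observed

def masked_squared_error(labels, predictions):
    """Mean squared error over observed entries only (stand-in loss)."""
    observed = ~np.isnan(labels)
    if not observed.any():
        return 0.0                        # fully unlabelled batch
    diff = predictions[observed] - labels[observed]
    return float(np.mean(diff ** 2))

# Three samples, two variables: partially, fully un-, and fully labelled.
labels = np.array([[1.0, np.nan],
                   [np.nan, np.nan],
                   [0.0, 3.0]])
predictions = np.array([[0.5, 2.0],
                        [1.0, 1.0],
                        [0.0, 2.0]])

filled, observed = fill_missing(labels, predictions)
loss = masked_squared_error(labels, predictions)
```

The same mask covers both settings: a fully unlabelled sample is a row with every entry missing, and a fully labelled sample is a row with none missing.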

Paper Structure (13 sections, 9 equations, 5 figures, 2 tables)

Figures (5)

  • Figure 1: Model outline. Green: observed, Grey: latent, Red: predicted, Blue: causal generative, Yellow: decoding. (left) Training; we use the $y$ predictions for decoding unless they are observed, (right) CF generation.
  • Figure 2: (a) DAG for Colour Morpho-MNIST, $d$: digit, $f$: foreground (digit) colour, $b$: background colour, $t$: thickness, $i$: intensity. (b) DAG for MIMIC-CXR, $s$: sex, $a$: age, $d$: disease status, $r$: race. $U$: respective exogenous noise variables.
  • Figure 2: (left) Colour Morpho-MNIST: Cause and mechanism experiment. (right) MIMIC-CXR: For each intervention $do(\cdot)$, the 3 rows correspond to training with 10%, 20%, 30% of variables labelled, all using CF regularisation. Semi-supervision in both settings (SSL, Flexible) outperforms pure supervision (Sup.).
  • Figure 3: Colour Morpho-MNIST: Accuracy of $\text{do}(d=k)$ on random test images for uniformly random $k \in \{0,\dots,9\}\setminus d$, where $d$ is the digit of the original image. For SSL, the $x$-axis represents the number of fully labelled samples; for Flexible, it represents the number of labels for each variable across all the samples. For SSL+Flexible, we use 600 randomly allocated labels for each variable in addition to the number of fully labelled samples denoted by the $x$-axis.
  • Figure 4: (a) CF Regularisation on MIMIC-CXR. (b) MIMIC-CXR CFs from model trained on 40% labels. From top-left: (1) original: white, healthy, 20-year-old male, (2) do($\text{age}\!=\!90$), (3) do(diseased), (4) do(asian), (5) do(female), (6) do(all).