De-confounding Representation Learning for Counterfactual Inference on Continuous Treatment via Generative Adversarial Network

Yonghe Zhao, Qiang Huang, Haolong Zeng, Yun Pen, Huiyan Sun

TL;DR

The paper tackles counterfactual inference for continuous treatments under confounding by proposing De-confounding Representation Learning (DRL), a nonparametric adversarial framework that learns covariate representations $X^{G}$ disentangled from the treatment $t$ while preserving counterfactual predictive power. DRL combines an adversarial Correlation network and a Discriminator with a CounterFactual module, leveraging random virtual representations $X^{R}$ and the learned representations $X^{G}$ to minimize both linear and nonlinear correlations with $t$. Extensive synthetic experiments show that DRL outperforms state-of-the-art methods in de-confounding and counterfactual accuracy, particularly under nonlinear relationships, and a real-world case study on MIMIC III demonstrates a causal link between red cell distribution width and mortality. The approach offers scalable, nonparametric counterfactual inference for continuous treatments, with practical implications for precision medicine and policy evaluation.
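The distinction between linear and nonlinear correlation matters because a treatment can depend strongly on a covariate while having zero Pearson correlation with it. The sketch below illustrates this on a hypothetical toy case where $t = x^2$, using distance correlation as a stand-in nonlinear dependence measure (it is zero in the population only when the two variables are independent). This is not the dependence measure DRL itself uses; the function name `distance_correlation` and the toy data are illustrative assumptions.

```python
import numpy as np

def distance_correlation(a, b):
    """Sample (V-statistic) distance correlation between two 1-D samples.
    In the population it is zero only when a and b are independent."""
    a = np.asarray(a, dtype=float)[:, None]
    b = np.asarray(b, dtype=float)[:, None]
    # Pairwise absolute-difference (distance) matrices, shape (n, n).
    da = np.abs(a - a.T)
    db = np.abs(b - b.T)
    # Double-centre each matrix: subtract row and column means, add the grand mean.
    A = da - da.mean(axis=0, keepdims=True) - da.mean(axis=1, keepdims=True) + da.mean()
    B = db - db.mean(axis=0, keepdims=True) - db.mean(axis=1, keepdims=True) + db.mean()
    dcov2 = max(float((A * B).mean()), 0.0)  # guard against tiny negative round-off
    dvar_prod = np.sqrt((A * A).mean() * (B * B).mean())
    return float(np.sqrt(dcov2 / dvar_prod))

x = np.linspace(-2.0, 2.0, 401)                   # hypothetical covariate on a symmetric grid
t = x ** 2                                        # treatment fully determined by x, but not linearly
z = np.random.default_rng(0).normal(size=x.size)  # independent noise for comparison

pearson_xt = float(np.corrcoef(x, t)[0, 1])  # ~0: linear correlation misses the dependence
dcor_xt = distance_correlation(x, t)         # clearly positive
dcor_xz = distance_correlation(x, z)         # small: x and z are independent
print(f"Pearson(x, t) = {pearson_xt:.3f}, dCor(x, t) = {dcor_xt:.3f}, dCor(x, z) = {dcor_xz:.3f}")
```

A de-confounding objective that only drives the Pearson correlation to zero would already be "satisfied" here, even though $t$ is a deterministic function of $x$.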

Abstract

Counterfactual inference for continuous, rather than binary, treatment variables is more common in real-world causal inference tasks. Although several sample reweighting methods based on the Marginal Structural Model already exist for eliminating confounding bias, they generally focus on removing the treatment's linear dependence on confounders and rely on the accuracy of assumed parametric models, which are usually unverifiable. In this paper, we propose a de-confounding representation learning (DRL) framework for counterfactual outcome estimation of continuous treatments that generates covariate representations disentangled from the treatment variables. DRL is a nonparametric model that eliminates both linear and nonlinear dependence between treatment and covariates. Specifically, to eliminate confounding bias, we adversarially train the correlations between the de-confounded representations and the treatment variables against the correlations between the covariate representations and the treatment variables. Further, a counterfactual inference network is embedded into the framework so that the learned representations serve both de-confounding and trustworthy inference. Extensive experiments on synthetic datasets show that DRL performs better at learning de-confounding representations and outperforms state-of-the-art counterfactual inference models for continuous treatment variables. In addition, we apply DRL to the real-world medical dataset MIMIC and demonstrate a detailed causal relationship between red cell distribution width (RDW) and mortality.
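For contrast, here is a minimal sketch of the kind of Marginal Structural Model reweighting the abstract critiques: stabilized inverse-probability weights $w = f(t) / f(t \mid x)$, with both densities assumed Gaussian and the conditional mean assumed linear in $x$. This is a generic textbook construction, not a specific baseline from the paper; the helper names (`stabilized_weights`, `weighted_corr`) and the toy data are hypothetical. It shows both halves of the critique: when the assumed model is correct the weights remove the linear dependence, but when $t$ depends on $x^2$ the linear model finds nothing to correct and the dependence survives.

```python
import numpy as np

def normal_pdf(v, mean, sd):
    """Density of N(mean, sd^2) evaluated elementwise at v."""
    return np.exp(-0.5 * ((v - mean) / sd) ** 2) / (sd * np.sqrt(2.0 * np.pi))

def stabilized_weights(x, t):
    """Stabilized weights f(t) / f(t | x) under ASSUMED Gaussian models:
    t ~ N(mean, sd^2) marginally, and t | x ~ N(a + b * x, sigma^2) fit by OLS."""
    design = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(design, t, rcond=None)
    fitted = design @ coef
    numerator = normal_pdf(t, t.mean(), t.std())
    denominator = normal_pdf(t, fitted, (t - fitted).std())
    return numerator / denominator

def weighted_corr(a, b, w):
    """Pearson correlation of a and b in the pseudo-population weighted by w."""
    w = w / w.sum()
    da = a - np.sum(w * a)
    db = b - np.sum(w * b)
    return float(np.sum(w * da * db) / np.sqrt(np.sum(w * da**2) * np.sum(w * db**2)))

rng = np.random.default_rng(0)
n = 5000
x = rng.normal(size=n)
ones = np.ones(n)

# Case 1: t depends linearly on x, so the assumed conditional model is correct.
t_lin = 0.5 * x + rng.normal(size=n)
w_lin = stabilized_weights(x, t_lin)
r_lin_raw = weighted_corr(t_lin, x, ones)  # about 0.45 before weighting
r_lin_w = weighted_corr(t_lin, x, w_lin)   # close to 0 after weighting

# Case 2: t depends on x**2; the linear model finds almost no slope, the weights
# stay near 1, and the nonlinear dependence between t and x is left in place.
t_sq = x**2 + 0.5 * rng.normal(size=n)
w_sq = stabilized_weights(x, t_sq)
r_sq_raw = weighted_corr(t_sq, x**2, ones)
r_sq_w = weighted_corr(t_sq, x**2, w_sq)

print(f"linear:    corr(t, x)   {r_lin_raw:.3f} -> {r_lin_w:.3f}")
print(f"nonlinear: corr(t, x^2) {r_sq_raw:.3f} -> {r_sq_w:.3f}")
```

The weights are only as good as the assumed treatment model, which is the unverifiable parametric assumption that DRL's nonparametric representations are meant to sidestep.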

Paper Structure (21 sections, 17 equations, 4 figures, 2 tables)

Figures (4)

  • Figure 1: The issues caused by confounding covariates are twofold: (i) the distribution of $X$ differs across discrete $t$ values; (ii) $X$ and continuous $t$ values are distributionally dependent.
  • Figure 2: Overview of the de-confounding representation learning framework, which contains four sub-modules: Generator, Discriminator, Correlation network, and CounterFactual module.
  • Figure 3: Comparison of how the correlation between the treatment variables and the covariates varies.
  • Figure 4: The average MTEF of RDW on mortality within one year of hospital discharge.
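Figure 1 (ii) names the problem DRL targets for continuous treatments: $X$ and $t$ are dependent. The sketch below shows, on a hypothetical linear data-generating process chosen only for illustration (not one of the paper's synthetic settings), why that dependence biases a naive estimate of the effect of $t$, and how conditioning on the confounder repairs it. Here the correct linear outcome model is assumed known, which real data rarely allow; the helper `ols_coef` and all coefficients are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5000
x = rng.normal(size=n)                      # hypothetical confounder
t = x + rng.normal(size=n)                  # continuous treatment that depends on x
y = 2.0 * t + 3.0 * x + rng.normal(size=n)  # true effect of a unit increase in t is 2.0

def ols_coef(columns, target):
    """Least-squares coefficients for target ~ intercept + columns."""
    design = np.column_stack([np.ones(len(target)), *columns])
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return coef

naive = ols_coef([t], y)[1]  # ignores x: about 2 + 3 * cov(x, t) / var(t) = 3.5
coef = ols_coef([t, x], y)   # conditions on x
adjusted = coef[1]           # about 2.0

# Model-based counterfactual for unit 0: estimated E[Y(t = 1.0) | x = x[0]].
y_cf = coef[0] + coef[1] * 1.0 + coef[2] * x[0]
print(f"naive slope {naive:.2f}, adjusted slope {adjusted:.2f}, counterfactual {y_cf:.2f}")
```

With continuous $t$ there are no treatment groups to rebalance, so the bias shows up as the dependence between $X$ and $t$ leaking into the estimated slope.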