Deep Learning for Causal Inference: A Comparison of Architectures for Heterogeneous Treatment Effect Estimation

Demetrios Papakostas, Andrew Herren, P. Richard Hahn, Francisco Castillo

TL;DR

The work tackles estimating heterogeneous treatment effects under a binary treatment using deep neural networks that model $Y$ as $Y = \alpha(X) + \beta(X)Z$. It compares three neural architectures (Farrell's jointly trained shared network, BCF-nnet with separate networks, and a naive two-network regression) against a linear baseline, highlighting the roles of the propensity function $\pi(X)$ and the prognostic function $\alpha(X)$. In simulations, BCF-nnet often outperforms the alternatives when the treatment effect $\beta(X)$ is small relative to the prognostic signal $\alpha(X)$, while the shared network can excel when the treatment signal is strong and data are plentiful; the performance gaps shrink as sample size grows. A real-data example on stress and sleep demonstrates practical applicability and adaptability to multimodal data, and the PyTorch-based implementation allows flexible regularization and scales to future extensions.
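The additive model above can be made concrete with a small simulation. The sketch below is hypothetical: the functional forms for $\alpha$, $\beta$, and $\pi$ are our own choices, not the paper's data-generating process. They are picked so that units with a larger prognostic value $\alpha(X)$ are more likely to be treated (targeted selection), while $\beta(X)$ stays small. The simulation shows that a naive treated-minus-control difference is badly biased for the average of $\beta(X)$. A regression that adjusts for $X$ recovers that average, but only because the assumed truth here happens to be linear.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000

# Hypothetical data-generating process (assumed for illustration, not the paper's).
x = rng.normal(size=n)
alpha = 2.0 + x                              # prognostic function alpha(X)
beta = 0.2 + 0.1 * x                         # small heterogeneous effect beta(X)
pi = 1.0 / (1.0 + np.exp(-(alpha - 2.0)))    # propensity rises with alpha: targeted selection
z = rng.binomial(1, pi)                      # binary treatment Z
y = alpha + beta * z + rng.normal(scale=0.5, size=n)

true_ate = beta.mean()

# Naive estimate: difference in mean outcomes, ignoring X entirely.
naive = y[z == 1].mean() - y[z == 0].mean()

# Linear regression adjustment: y ~ 1 + x + z + x*z, so beta_hat(X) = b_z + b_xz * x.
design = np.column_stack([np.ones(n), x, z, x * z])
coef, *_ = np.linalg.lstsq(design, y, rcond=None)
beta_hat = coef[2] + coef[3] * x

print(f"true mean beta:                {true_ate:.3f}")
print(f"naive difference in means:     {naive:.3f}")
print(f"regression-adjusted mean beta: {beta_hat.mean():.3f}")
```

The linear fit succeeds only because $\alpha$ and $\beta$ are assumed linear in $x$. The neural architectures compared in the paper target the case where neither function has a known form.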

Abstract

Causal inference has gained much popularity in recent years, with interest ranging from academic to industrial to educational settings and everything in between. Concurrently, the study and use of neural networks has also grown profoundly (albeit at a far faster rate). What we aim to do in this blog write-up is demonstrate a neural network architecture for causal inference. We develop a fully connected neural network implementation of the popular Bayesian Causal Forest (BCF) algorithm, a state-of-the-art tree-based method for estimating heterogeneous treatment effects. We compare our implementation to existing neural network methods for causal inference and show improved performance in simulation settings. We apply our method to a dataset examining the effect of stress on sleep.
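To make the architectural contrast concrete, the sketch below writes out untrained forward passes for two designs. The first is a shared network in the spirit of Figure 1, with one hidden stack and two output units for $\alpha(X)$ and $\beta(X)$. The second is a separate-network design in the spirit of Figure 2. This is a hypothetical NumPy illustration, not the authors' PyTorch code. The following are all assumptions:

- two hidden layers of 4 ReLU units each,
- an identity choice for $G$,
- feeding an estimated propensity $\hat\pi(X)$ into the $\alpha$ network, mirroring the parameterization of tree-based BCF.

```python
import numpy as np

def init_mlp(rng, sizes):
    """Random (untrained) weights for a fully connected net with the given layer sizes."""
    return [(rng.normal(scale=1 / np.sqrt(m), size=(m, k)), np.zeros(k))
            for m, k in zip(sizes[:-1], sizes[1:])]

def mlp(params, h):
    """Forward pass: ReLU on hidden layers, linear output layer."""
    for i, (w, b) in enumerate(params):
        h = h @ w + b
        if i < len(params) - 1:
            h = np.maximum(h, 0.0)
    return h

def shared_forward(params, x, z):
    # One network, two output units: column 0 is alpha(X), column 1 is beta(X).
    out = mlp(params, x)
    alpha, beta = out[:, 0], out[:, 1]
    return alpha + beta * z          # G taken to be the identity (assumption)

def separate_forward(alpha_params, beta_params, x, pi_hat, z):
    # alpha network sees X plus the estimated propensity (assumed, as in tree-based BCF);
    # beta network sees X only, so the two functions share no weights.
    alpha = mlp(alpha_params, np.column_stack([x, pi_hat]))[:, 0]
    beta = mlp(beta_params, x)[:, 0]
    return alpha + beta * z          # G taken to be the identity (assumption)

rng = np.random.default_rng(1)
n, p = 5, 3                                   # 3 covariates, as in Figure 1
x = rng.normal(size=(n, p))
z = rng.binomial(1, 0.5, size=n)
pi_hat = np.full(n, 0.5)                      # placeholder propensity estimates

# 4 nodes per hidden layer as in Figure 1; using two hidden layers is an assumption.
shared = init_mlp(rng, [p, 4, 4, 2])
alpha_net = init_mlp(rng, [p + 1, 4, 4, 1])
beta_net = init_mlp(rng, [p, 4, 4, 1])

print(shared_forward(shared, x, z))
print(separate_forward(alpha_net, beta_net, x, pi_hat, z))
```

In the shared design, errors in $\alpha$ and $\beta$ both update the same hidden weights. In the separate design, each function has its own weights and could, in a fuller implementation, be regularized on its own.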

Paper Structure

This paper contains 12 sections, 7 equations, 7 figures, and 2 tables.

Figures (7)

  • Figure 1: The Farrell method with a 3-dimensional covariate vector $X$ and 4 nodes in each hidden layer (in practice, these networks are usually much larger). $G$ is an activation function that takes $\alpha(X)+\beta(X)Z$ as an argument.
  • Figure 2: The BCF nnet architecture, where $G(\cdot)$ is an activation function that takes $\alpha(X)+\beta(X)Z$ as an argument.
  • Figure 3: Left: histogram of $\beta(X)$. Right: $\alpha(X)$ versus $\pi(X)$, indicative of strong targeted selection. For this particular realization of the data-generating process (eq:dgp1) with $n = 10{,}000$, the mean of $\beta(X)$ is 0.20, the mean of $\alpha(X)$ is 1.95, and $\pi(X)$ ranges over $(0.11, 0.90)$ with a mean of 0.37.
  • Figure 4: Left: bias under the data-generating process for different sample sizes $n$. Right: RMSE. Both panels are for the "small" $\beta$ setting.
  • Figure 5: Individual biases and RMSEs across the 100 Monte Carlo runs for the shared and BCF architectures.
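Figures 4 and 5 report the bias and RMSE of estimated treatment effects across Monte Carlo replications. The captions do not spell out the definitions, so the sketch below assumes the common ones. For each run, bias is the average of $\hat\beta(x_i) - \beta(x_i)$ over units, and RMSE is the square root of the average squared error. Summaries across runs are then plain means. The arrays are made-up toy numbers, not results from the paper.

```python
import numpy as np

def run_metrics(beta_hat, beta_true):
    """Bias and RMSE of estimated effects for one Monte Carlo run (assumed definitions)."""
    err = np.asarray(beta_hat, dtype=float) - np.asarray(beta_true, dtype=float)
    return err.mean(), np.sqrt(np.mean(err ** 2))

# Toy inputs: 3 hypothetical runs with 2 units each (illustrative only).
beta_true = np.array([0.2, 0.4])
estimates = [
    np.array([0.3, 0.3]),   # errors +0.1 and -0.1 cancel in the bias
    np.array([0.3, 0.5]),   # uniformly too high
    np.array([0.2, 0.4]),   # exact
]

per_run = [run_metrics(b, beta_true) for b in estimates]
biases = [b for b, _ in per_run]
rmses = [r for _, r in per_run]
print("per-run bias:", np.round(biases, 3))
print("per-run RMSE:", np.round(rmses, 3))
print("mean bias over runs:", np.mean(biases))
print("mean RMSE over runs:", np.mean(rmses))
```

The first toy run shows why both metrics are worth reporting: errors of opposite sign cancel in the bias but not in the RMSE.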