Deep Learning for Causal Inference: A Comparison of Architectures for Heterogeneous Treatment Effect Estimation
Demetrios Papakostas, Andrew Herren, P. Richard Hahn, Francisco Castillo
TL;DR
The work tackles estimating heterogeneous treatment effects under a binary treatment $Z$ using deep neural networks that model the outcome as $Y = α(X) + β(X)Z$. It compares three neural architectures (Farrell's jointly trained shared network, BCF-nnet with separate networks, and a naive two-network regression) against a linear baseline, highlighting the role of the propensity function $π(X)$ and the prognostic function $α(X)$. In simulations, BCF-nnet often outperforms the alternatives when the treatment effect $β(X)$ is small relative to the prognostic signal $α(X)$, while the shared network can excel when the treatment signal is strong and data are plentiful; the performance gaps shrink as the sample size grows. A real-data example on stress and sleep demonstrates practical applicability and adaptability to multimodal data, with a PyTorch-based implementation that allows flexible regularization and scales to future extensions.
Abstract
Causal inference has gained much popularity in recent years, with interest spanning academia, industry, and education. Concurrently, the study and use of neural networks have grown even faster. In this write-up we demonstrate a neural network architecture for causal inference. We develop a fully connected neural network implementation of the popular Bayesian Causal Forest algorithm, a state-of-the-art tree-based method for estimating heterogeneous treatment effects. We compare our implementation to existing neural network causal inference methods and show improved performance in simulation settings. Finally, we apply our method to a dataset examining the effect of stress on sleep.
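To make the model $Y = α(X) + β(X)Z$ concrete, the following is a minimal sketch of a data-generating process of that form, written in plain Python. The particular choices of `alpha`, `beta`, and `pi`, the noise level, and the sample size are hypothetical and chosen only for illustration; they are not the simulation settings used in this work. The sketch shows why the propensity function matters: when $π(X)$ and $α(X)$ both increase with $X$, a naive difference in group means mixes the prognostic signal into the estimated effect.

```python
import math
import random


def alpha(x):
    """Hypothetical prognostic function: expected outcome when Z = 0."""
    return 2.0 * x


def beta(x):
    """Hypothetical treatment effect, small relative to alpha(x)."""
    return 0.2 + 0.1 * x


def pi(x):
    """Hypothetical propensity P(Z = 1 | X = x), increasing in x."""
    return 1.0 / (1.0 + math.exp(-2.0 * x))


def simulate(n, seed=0):
    """Draw (x, z, y) triples from Y = alpha(X) + beta(X) * Z + noise."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        z = 1 if rng.random() < pi(x) else 0  # treatment depends on x
        y = alpha(x) + beta(x) * z + rng.gauss(0.0, 0.5)
        rows.append((x, z, y))
    return rows


def mean(values):
    return sum(values) / len(values)


rows = simulate(5000)

# Average of beta(x) over the sample: the average treatment effect
# implied by the model.
true_ate = mean([beta(x) for x, _, _ in rows])

# Naive contrast of treated and control means, ignoring confounding by x.
treated = [y for _, z, y in rows if z == 1]
control = [y for _, z, y in rows if z == 0]
naive = mean(treated) - mean(control)

print(f"average beta(x): {true_ate:.3f}")
print(f"naive difference in means: {naive:.3f}")
```

In this setup the naive contrast is far from the average of `beta(x)`, because treated units tend to have larger `x` and therefore larger `alpha(x)`. Estimating $α(X)$ and $β(X)$ as separate functions of the covariates is the kind of structure the architectures compared here aim to exploit.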
