Perfect Match: A Simple Method for Learning Representations For Counterfactual Inference With Neural Networks
Patrick Schwab, Lorenz Linhardt, Walter Karlen
TL;DR
Perfect Match (PM) is a simple, architecture-agnostic minibatch augmentation strategy for counterfactual inference with neural networks. PM matches observations on a balancing score (e.g., the propensity score) and augments each sample in a minibatch with its nearest neighbours from the other treatment groups. The resulting minibatches are virtually randomized, which reduces treatment assignment bias for any number of treatments. The authors extend the TARNET architecture to multi-treatment settings, define multi-treatment evaluation metrics (mPEHE and mATE), and report that PM outperforms several state-of-the-art methods on the IHDP and News datasets, with NN-PEHE serving as a superior model-selection criterion. Because PM is simple, scalable, and performs well empirically, it is of practical interest for causal inference in domains such as healthcare and policy, where observational data predominate. The work also highlights the effectiveness of minibatch-level balancing and provides open benchmarks for future research.
Abstract
Learning representations for counterfactual inference from observational data is of high practical relevance for many domains, such as healthcare, public policy and economics. Counterfactual inference enables one to answer "What if...?" questions, such as "What would be the outcome if we gave this patient treatment $t_1$?". However, current methods for training neural networks for counterfactual inference on observational data are either overly complex, limited to settings with only two available treatments, or both. Here, we present Perfect Match (PM), a method for training neural networks for counterfactual inference that is easy to implement, compatible with any architecture, does not add computational complexity or hyperparameters, and extends to any number of treatments. PM is based on the idea of augmenting samples within a minibatch with their propensity-matched nearest neighbours. Our experiments demonstrate that PM outperforms a number of more complex state-of-the-art methods in inferring counterfactual outcomes across several benchmarks, particularly in settings with many treatments.
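The core idea of augmenting a minibatch with propensity-matched nearest neighbours can be illustrated with a minimal sketch. This is not the authors' implementation: the function name `perfect_match_batch`, its arguments, and the use of Euclidean distance on the scores are assumptions made for illustration, and the balancing scores are assumed to have been estimated beforehand (for example, by a separately fitted propensity model).

```python
import numpy as np

def perfect_match_batch(batch_idx, treatments, scores, num_treatments):
    """Augment a minibatch with score-matched neighbours (illustrative sketch).

    For every sample in the batch, add the training sample from each *other*
    treatment group whose balancing score is closest to that sample's score.
    All names here are hypothetical; `scores` is assumed to be precomputed.
    """
    treatments = np.asarray(treatments)
    scores = np.asarray(scores, dtype=float)
    if scores.ndim == 1:
        # Treat scalar scores as 1-dimensional score vectors.
        scores = scores[:, None]

    # Indices of the training samples that received each treatment.
    pools = [np.flatnonzero(treatments == k) for k in range(num_treatments)]

    augmented = []
    for i in batch_idx:
        augmented.append(int(i))
        for k in range(num_treatments):
            if k == treatments[i] or pools[k].size == 0:
                continue  # skip the sample's own treatment and empty groups
            # Euclidean distance between balancing scores (an assumption).
            dists = np.linalg.norm(scores[pools[k]] - scores[i], axis=1)
            augmented.append(int(pools[k][np.argmin(dists)]))
    return np.array(augmented)

# Toy data: six samples, three treatments, scalar balancing scores.
treatments = [0, 0, 1, 1, 2, 2]
scores = [0.1, 0.9, 0.2, 0.8, 0.15, 0.7]
print(perfect_match_batch([0, 1], treatments, scores, num_treatments=3))
# prints [0 2 4 1 3 5]
```

In this toy example each original sample is followed by one matched sample per other treatment, so every treatment is equally represented in the augmented batch. The augmented indices would then be used to gather covariates, treatments, and factual outcomes for an ordinary training step, which is why this kind of augmentation does not depend on the network architecture.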
