Perfect Match: A Simple Method for Learning Representations For Counterfactual Inference With Neural Networks

Patrick Schwab, Lorenz Linhardt, Walter Karlen

TL;DR

Perfect Match (PM) is a simple, architecture-agnostic minibatch augmentation strategy for counterfactual inference with neural networks. It matches observations on a balancing score (e.g., the propensity score) and augments each sample in a minibatch with its nearest neighbours, producing virtually randomized training batches that reduce treatment assignment bias for any number of treatments. The authors extend the TARNET architecture to multi-treatment settings, define multi-treatment evaluation metrics (mPEHE and mATE), and show that PM outperforms several state-of-the-art methods on the IHDP and News datasets. They also report that NN-PEHE, a nearest-neighbour approximation of PEHE, is a better model-selection criterion than the mean squared error (MSE). Because PM is simple, scalable, and performs well empirically, it is a practical option for causal inference in domains such as healthcare and policy, where observational data predominate. The work also highlights the effectiveness of minibatch-level balancing and provides open benchmarks for future research.

Abstract

Learning representations for counterfactual inference from observational data is of high practical relevance for many domains, such as healthcare, public policy and economics. Counterfactual inference enables one to answer "What if...?" questions, such as "What would be the outcome if we gave this patient treatment $t_1$?". However, current methods for training neural networks for counterfactual inference on observational data are either overly complex, limited to settings with only two available treatments, or both. Here, we present Perfect Match (PM), a method for training neural networks for counterfactual inference that is easy to implement, compatible with any architecture, does not add computational complexity or hyperparameters, and extends to any number of treatments. PM is based on the idea of augmenting samples within a minibatch with their propensity-matched nearest neighbours. Our experiments demonstrate that PM outperforms a number of more complex state-of-the-art methods in inferring counterfactual outcomes across several benchmarks, particularly in settings with many treatments.
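
The augmentation step described in the abstract can be illustrated with a minimal sketch. It assumes that every training sample already has a balancing score (for example an estimated propensity score) and that matches are found by Euclidean distance on that score. The function name `pm_augment_batch` and its arguments are hypothetical and are not taken from the authors' code.

```python
import numpy as np


def pm_augment_batch(batch_idx, treatments, scores, num_treatments):
    """Augment a minibatch with one nearest match per other treatment.

    Illustrative sketch only (names and signature are assumptions):
    `scores` holds one balancing score (or score vector) per training sample.
    """
    treatments = np.asarray(treatments)
    scores = np.asarray(scores, dtype=float)
    if scores.ndim == 1:
        scores = scores[:, None]  # treat scalar scores as 1-d vectors

    # Group training-set indices by the treatment each sample received.
    by_treatment = [np.flatnonzero(treatments == t) for t in range(num_treatments)]

    augmented = []
    for i in batch_idx:
        augmented.append(int(i))
        for t in range(num_treatments):
            candidates = by_treatment[t]
            if t == treatments[i] or candidates.size == 0:
                continue
            # Closest sample (by balancing score) that received treatment t.
            dists = np.linalg.norm(scores[candidates] - scores[i], axis=1)
            augmented.append(int(candidates[np.argmin(dists)]))
    return np.array(augmented)


# Toy data: five samples, three treatments, scalar balancing scores.
treatments = np.array([0, 0, 1, 1, 2])
scores = np.array([0.1, 0.9, 0.2, 0.8, 0.5])
augmented = pm_augment_batch([0, 1], treatments, scores, num_treatments=3)
print(augmented)  # [0 2 4 1 3 4]
```

In this toy example each of the two original samples is followed by its closest match from each of the other two treatment groups, so the augmented batch contains every treatment equally often. The augmented indices would then be used to gather covariates, treatments, and outcomes for an ordinary training step.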

Paper Structure

This paper contains 28 sections, 1 theorem, 5 equations, 5 figures, 1 table, 1 algorithm.

Key Result

Theorem 1

Upon convergence, under assumption (1) and for $N\to\infty$, a neural network $\hat{f}$ trained according to the PM algorithm is a consistent estimator of the true potential outcomes $Y$ for each $t$.

Figures (5)

  • Figure 1: The TARNET architecture with $k$ heads for the multiple treatment setting. Each head predicts a potential outcome $\hat{y}_j$, and is only trained on samples that received the respective treatment.
  • Figure 2: Correlation analysis of the real PEHE (y-axis) with the mean squared error (MSE; left) and the nearest neighbour approximation of the precision in estimation of heterogeneous effect (NN-PEHE; right) across more than 20,000 model evaluations on the validation set of IHDP. Scatterplots show a subsample of 1,400 data points. $\rho$ indicates the Pearson correlation.
  • Figure 3:
  • Figure 4: Change in error (y-axes) in terms of precision in estimation of heterogeneous effect (PEHE) and average treatment effect (ATE) when increasing the percentage of matches in each minibatch (x-axis). Symbols correspond to the mean value of $\hat{\epsilon}_\text{mATE}$ (red) and $\sqrt{\hat{\epsilon}_\text{mPEHE}}$ (blue) on the test set of News-8 across 50 repeated runs with new outcomes (lower is better).
  • Figure 5: Comparison of several state-of-the-art methods for counterfactual inference on the test set of the News-8 dataset when varying the treatment assignment imbalance $\kappa$ (x-axis), i.e. how much the treatment assignment is biased towards more effective treatments. Symbols correspond to the mean value of $\sqrt{\hat{\epsilon}_\text{mPEHE}}$ across 50 repeated runs with new outcomes (lower is better).
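
The multi-treatment metrics named in the captions above can be sketched as follows. This sketch assumes that mPEHE and mATE are obtained by averaging the usual two-treatment PEHE and ATE errors over all unordered pairs of treatments, and that true potential outcomes are known for every treatment, as is the case in semi-synthetic benchmarks. The function names are hypothetical.

```python
from itertools import combinations

import numpy as np


def pehe(y_true, y_pred, i, j):
    """Mean squared error of the estimated effect of treatment i versus j."""
    true_effect = y_true[:, i] - y_true[:, j]
    pred_effect = y_pred[:, i] - y_pred[:, j]
    return float(np.mean((true_effect - pred_effect) ** 2))


def ate_error(y_true, y_pred, i, j):
    """Absolute error of the estimated average effect of treatment i versus j."""
    true_ate = np.mean(y_true[:, i] - y_true[:, j])
    pred_ate = np.mean(y_pred[:, i] - y_pred[:, j])
    return float(abs(true_ate - pred_ate))


def multi_treatment_errors(y_true, y_pred):
    """Average the pairwise errors over all unordered treatment pairs.

    y_true, y_pred: arrays of shape (N, k) with one column per treatment.
    Returns (mPEHE, mATE) under the pairwise-averaging assumption above.
    """
    pairs = list(combinations(range(y_true.shape[1]), 2))
    m_pehe = np.mean([pehe(y_true, y_pred, i, j) for i, j in pairs])
    m_ate = np.mean([ate_error(y_true, y_pred, i, j) for i, j in pairs])
    return float(m_pehe), float(m_ate)


# Toy check: predictions are exact except for a +1 bias on treatment 2.
y_true = np.array([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]])
y_pred = y_true + np.array([0.0, 0.0, 1.0])
m_pehe, m_ate = multi_treatment_errors(y_true, y_pred)
```

In the toy check, two of the three treatment pairs involve the biased treatment, so both averaged errors equal 2/3. The figures report the square root of mPEHE, which would be `np.sqrt(m_pehe)` here.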

Theorems & Definitions (1)

  • Theorem 1