Learning Counterfactually Invariant Predictors
Francesco Quinzan, Cecilia Casolo, Krikamol Muandet, Yucen Luo, Niki Kilbertus
TL;DR
The paper tackles learning predictors that are invariant to counterfactual changes in the data-generating process by casting counterfactual invariance (CI) as a conditional independence constraint in the observational distribution, under a graphical injectivity assumption. It introduces Counterfactually Invariant Prediction (CIP), a model-agnostic framework that uses the Hilbert-Schmidt Conditional Independence Criterion (HSCIC) to enforce CI while allowing mixed data types; a tunable parameter $\gamma$ controls how strongly CI is enforced, trading accuracy for invariance. The authors provide theoretical support linking CI to conditional independence, describe practical estimation of HSCIC from samples, and validate CIP on synthetic and real datasets (including dSprites and UCI Adult), showing favorable MSE/VCF trade-offs and improved counterfactual robustness. They also discuss computational considerations, limitations (notably the need for a known causal graph and for scalable CI estimation), and potential extensions to causal representation learning and broader CI notions. Overall, CIP offers a principled, kernel-based route to robust, fair, and generalizable predictors under counterfactual shifts.
Abstract
Notions of counterfactual invariance (CI) have proven essential for predictors that are fair, robust, and generalizable in the real world. We propose graphical criteria that yield a sufficient condition for a predictor to be counterfactually invariant in terms of a conditional independence in the observational distribution. In order to learn such predictors, we propose a model-agnostic framework, called Counterfactually Invariant Prediction (CIP), building on the Hilbert-Schmidt Conditional Independence Criterion (HSCIC), a kernel-based conditional dependence measure. Our experimental results demonstrate the effectiveness of CIP in enforcing counterfactual invariance across various simulated and real-world datasets including scalar and multi-variate settings.
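To make the idea of an HSCIC-regularized objective concrete, the following is a minimal NumPy sketch, not the authors' implementation. It assumes Gaussian (RBF) kernels with a fixed bandwidth, a single ridge parameter for the kernel regression weights, a mean-squared-error prediction loss, and a penalty averaged over the observed conditioning points. The names `rbf_gram`, `hscic`, `cip_loss`, `bandwidth`, and `ridge` are hypothetical, as are the variable roles (`a` for the variables the predictor should be invariant to, `z` for the conditioning set). The estimator is the plug-in squared RKHS distance between the empirical conditional joint embedding and the product of the conditional marginal embeddings.

```python
import numpy as np


def rbf_gram(x, bandwidth=1.0):
    """Gaussian (RBF) Gram matrix; 1-D inputs are treated as n scalar samples."""
    x = np.asarray(x, dtype=float)
    x = x.reshape(x.shape[0], -1)
    sq = np.sum(x ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * (x @ x.T)
    return np.exp(-np.maximum(d2, 0.0) / (2.0 * bandwidth ** 2))


def hscic(y_hat, a, z, ridge=0.01):
    """Plug-in HSCIC estimate of the dependence of y_hat on a given z.

    Assumption: the same ridge parameter is used for all conditional
    embeddings, and the result is averaged over the observed z_j.
    """
    n = np.asarray(z).shape[0]
    K_y, K_a, K_z = rbf_gram(y_hat), rbf_gram(a), rbf_gram(z)
    # Column j holds the weights w(z_j) = (K_z + n * ridge * I)^{-1} k_z(z_j).
    W = np.linalg.solve(K_z + n * ridge * np.eye(n), K_z)
    # w^T (K_y * K_a) w, evaluated for every column of W at once.
    first = np.einsum("ij,ij->j", W, (K_y * K_a) @ W)
    # w^T [(K_y w) * (K_a w)]
    second = np.einsum("ij,ij->j", W, (K_y @ W) * (K_a @ W))
    # (w^T K_y w) (w^T K_a w)
    third = np.einsum("ij,ij->j", W, K_y @ W) * np.einsum("ij,ij->j", W, K_a @ W)
    return float(np.mean(first - 2.0 * second + third))


def cip_loss(y_hat, y, a, z, gamma=1.0):
    """MSE plus a gamma-weighted HSCIC penalty (hypothetical objective form)."""
    mse = float(np.mean((np.asarray(y_hat, dtype=float) - np.asarray(y, dtype=float)) ** 2))
    return mse + gamma * hscic(y_hat, a, z)


# Toy data: a depends on z, and y depends on both.
rng = np.random.default_rng(0)
n = 50
z = rng.normal(size=n)
a = z + rng.normal(size=n)
y = a + z + rng.normal(size=n)
y_hat = a + z  # a predictor that uses a directly

print("HSCIC penalty:", hscic(y_hat, a, z))
print("Loss at gamma=0:", cip_loss(y_hat, y, a, z, gamma=0.0))
print("Loss at gamma=2:", cip_loss(y_hat, y, a, z, gamma=2.0))
```

Setting `gamma=0` recovers the plain prediction loss, while larger values put more weight on the invariance penalty. In a training loop the same expression would be written in an automatic-differentiation framework so that gradients flow through `y_hat`; the dense solve costs on the order of $n^3$ operations, which is one reason scalable estimation is listed among the limitations.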
