A Partial Initialization Strategy to Mitigate the Overfitting Problem in CATE Estimation with Hidden Confounding
Chuan Zhou, Yaxuan Li, Chunyuan Zheng, Haiteng Zhang, Haoxuan Li, Mingming Gong
TL;DR
This work tackles CATE estimation under hidden confounding by fusing large-scale observational (OBS) data with small-scale randomized controlled trial (RCT) data. It introduces a two-stage pretraining-finetuning framework (TSPF) that first learns a foundational covariate representation from OBS data and then adapts it to the RCT data with a representation adapter and partially initialized prediction heads to mitigate the bias from hidden confounding. Key components include covariate-balancing representation learning with an integral probability metric (IPM), a reconstruction objective, a mutual-information regularizer based on the contrastive log-ratio upper bound (CLUB), and a carefully designed partial initialization strategy that prevents overfitting to the small RCT sample. On the semi-synthetic IHDP and Jobs datasets, the proposed method outperforms baselines and delivers robust CATE estimates without requiring strict linear or additive assumptions on the data-generating process, highlighting its practical potential for data fusion in causal inference.
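The balancing and reconstruction terms can be combined into a single first-stage loss. The sketch below is a minimal illustration under stated assumptions, not the authors' implementation: it uses the linear-kernel MMD (squared distance between the treated and control mean representations) as the IPM, mean squared error for both the factual-outcome and reconstruction terms, and hypothetical weights `alpha` and `beta`. The CLUB mutual-information term is omitted because it requires a separately trained variational network.

```python
import numpy as np


def linear_mmd(rep: np.ndarray, t: np.ndarray) -> float:
    """Linear-kernel MMD: squared distance between group mean representations.

    Used here as one simple IPM (an assumption; the summary does not name
    which IPM the paper uses).
    """
    mean_treated = rep[t == 1].mean(axis=0)
    mean_control = rep[t == 0].mean(axis=0)
    return float(np.sum((mean_treated - mean_control) ** 2))


def stage1_loss(y, y_hat, x, x_recon, rep, t, alpha=1.0, beta=1.0):
    """Hypothetical stage-1 objective.

    factual MSE + alpha * IPM(treated vs. control representations)
    + beta * reconstruction MSE. The CLUB regularizer is deliberately left out.
    """
    factual = np.mean((y - y_hat) ** 2)
    recon = np.mean((x - x_recon) ** 2)
    return float(factual + alpha * linear_mmd(rep, t) + beta * recon)


if __name__ == "__main__":
    # Two treated units (t = 1) followed by two control units (t = 0).
    rep = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 1.0], [0.0, 3.0]])
    t = np.array([1, 1, 0, 0])
    print(linear_mmd(rep, t))  # 8.0
```

Here the treated and control mean representations are `[2, 0]` and `[0, 2]`, so the linear MMD is 2² + 2² = 8.0; identical group means would give 0, which is what the balancing term pushes toward.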
Abstract
Estimating the conditional average treatment effect (CATE) from observational data plays a crucial role in areas such as e-commerce, healthcare, and economics. Existing studies mainly rely on the strong ignorability assumption that there are no hidden confounders; the existence of such confounders cannot be tested from observational data, and it can invalidate any causal conclusion. In contrast, data collected from randomized controlled trials (RCTs) do not suffer from confounding but are usually limited by a small sample size. To avoid overfitting to the small-scale RCT data, we propose a novel two-stage pretraining-finetuning (TSPF) framework with a partial parameter initialization strategy to estimate the CATE in the presence of hidden confounding. In the first stage, a foundational representation of the covariates is trained on large-scale observational data to estimate counterfactual outcomes. In the second stage, we train an augmented representation of the covariates, which is concatenated with the foundational representation from the first stage to adjust for hidden confounding. Rather than training a separate network from scratch, we initialize part of the prediction heads with parameters learned in the first stage. Extensive experiments on two datasets validate the superiority of our approach.
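One way to read the partial initialization strategy is sketched below for a single hidden layer of a prediction head. This is an illustrative assumption, not the paper's exact scheme: in the second stage the head receives the concatenation of the foundational representation `phi` and the augmented representation `psi`, so the weight columns acting on `phi` (and the bias) are copied from the first-stage head, while the columns acting on `psi` are freshly drawn from a zero-mean Gaussian with a hypothetical `scale`. The function names `partially_init_layer` and `layer_forward` are also hypothetical.

```python
import numpy as np


def partially_init_layer(w_stage1, b_stage1, d_aug, scale=0.01, seed=0):
    """Hypothetical partial initialization of a head's first layer.

    w_stage1: (hidden, d_phi) weights learned in stage 1 on phi alone.
    b_stage1: (hidden,) bias learned in stage 1.
    d_aug:    dimension of the augmented representation psi.
    Returns weights of shape (hidden, d_phi + d_aug): the first d_phi columns
    are copied from stage 1, the last d_aug columns are drawn from N(0, scale**2).
    """
    rng = np.random.default_rng(seed)
    w_new = rng.normal(0.0, scale, size=(w_stage1.shape[0], d_aug))
    w = np.concatenate([w_stage1, w_new], axis=1)
    return w, b_stage1.copy()


def layer_forward(w, b, phi, psi=None):
    """One ReLU layer applied to phi (stage 1) or concat([phi, psi]) (stage 2)."""
    z = phi if psi is None else np.concatenate([phi, psi], axis=1)
    return np.maximum(z @ w.T + b, 0.0)


if __name__ == "__main__":
    w1 = np.array([[1.0, 2.0], [0.0, -1.0]])
    b1 = np.array([0.5, 0.0])
    phi = np.array([[1.0, 1.0]])       # foundational representation (stage 1)
    psi = np.array([[3.0, 4.0, 5.0]])  # augmented representation (stage 2)
    w2, b2 = partially_init_layer(w1, b1, d_aug=3, scale=0.0)
    print(layer_forward(w1, b1, phi).tolist())       # [[3.5, 0.0]]
    print(layer_forward(w2, b2, phi, psi).tolist())  # [[3.5, 0.0]]
```

With `scale=0.0` the new columns are zero, so the second-stage head starts out reproducing the first-stage output exactly and departs from it only as training on the RCT data moves the new weights away from zero. A small positive `scale` gives up this exact warm start but lets gradients reach the augmented representation from the very first update, since the gradient with respect to `psi` is proportional to the new columns.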
