Adapting to Shifting Correlations with Unlabeled Data Calibration
Minh Nguyen, Alan Q. Wang, Heejong Kim, Mert R. Sabuncu
TL;DR
Generalized Prevalence Adjustment (GPA) tackles distribution shifts across sites by adaptively exploiting unstable confounder information, without requiring labels or confounder data at test time. It extends CoPA in two ways: an EM-based procedure that estimates the conditional prevalence $P_b(Y|\mathbf{Z})$ at a new site from unlabeled samples, and a dropout-like input knockout that handles missing $\mathbf{Z}$. The method uses two networks, $f_\theta$ and $g_\phi$, together with a calibration step. Empirical results on Color MNIST and real-world medical imaging datasets (ISIC, CXR) show that GPA often outperforms standard baselines and CoPA, including when $\mathbf{Z}$ is partially missing. The work contributes a scalable, data-efficient method for improving out-of-distribution (OOD) generalization when $Y|\mathbf{Z}$ shifts, with practical implications for healthcare imaging, where site metadata can be incomplete or unstable.
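The TL;DR describes the input knockout only as "dropout-like". The sketch below shows one plausible form of it and is not the authors' implementation: it assumes categorical confounders encoded as non-negative integers and a hypothetical sentinel `MISSING = -1` that the confounder-aware network would learn to read as "unknown". During training, each confounder entry is independently replaced by the sentinel with probability `p`; the function name `knockout_confounders` is illustrative.

```python
import numpy as np

# Hypothetical sentinel code for "confounder value unavailable"; assumes real
# confounder categories are encoded as non-negative integers.
MISSING = -1


def knockout_confounders(z, p, rng):
    """Dropout-like knockout: independently replace each confounder entry
    with MISSING with probability p.

    z:   integer array of shape (n_samples, n_confounders).
    p:   knockout probability in [0, 1].
    rng: a numpy.random.Generator.
    Returns a new array; the input is left unchanged.
    """
    z = np.asarray(z)
    mask = rng.random(z.shape) < p  # True where the entry is knocked out
    z_out = z.copy()                # copy so the caller's array is not mutated
    z_out[mask] = MISSING
    return z_out


# Example: mask roughly 30% of confounder entries in a training batch.
rng = np.random.default_rng(0)
z_batch = np.array([[0, 1], [1, 2], [2, 0], [1, 1]])
z_train = knockout_confounders(z_batch, p=0.3, rng=rng)
```

Under this assumption, a site whose metadata lacks some confounder would simply feed `MISSING` in that position at test time, an encoding the model has already seen during training.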
Abstract
Distribution shifts between sites can seriously degrade model performance, since models are prone to exploiting unstable correlations. Thus, many methods try to find features that are stable across sites and discard unstable features. However, unstable features might carry complementary information that, if used appropriately, could increase accuracy. More recent methods try to adapt to unstable features at new sites to achieve higher accuracy. However, they make unrealistic assumptions or fail to scale to multiple confounding features. We propose Generalized Prevalence Adjustment (GPA), a flexible method that adjusts model predictions to the shifting correlations between the prediction target and confounders to safely exploit unstable features. GPA can infer the interaction between target and confounders in new sites using unlabeled samples from those sites. We evaluate GPA on several real and synthetic datasets and show that it outperforms competitive baselines.
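The abstract does not spell out how the target-site interaction is inferred. As a minimal sketch, assume (as in the classical EM procedure for prior-probability shift of Saerens et al., 2002) that $P(X|Y,\mathbf{Z})$ is stable across sites, so only $P(Y|\mathbf{Z})$ shifts. Bayes' rule then gives the adjustment $P_b(Y|X,\mathbf{Z}) \propto P_a(Y|X,\mathbf{Z})\,P_b(Y|\mathbf{Z})/P_a(Y|\mathbf{Z})$, and $P_b(Y|\mathbf{Z})$ can be estimated by running EM separately within each value of a discrete $\mathbf{Z}$. This is an illustration of that swapped-in classical technique, not the paper's exact procedure; the function names are hypothetical.

```python
import numpy as np


def adjust_posteriors(post_a, prior_a, prior_b):
    """Prevalence adjustment: p_b(y|x,z) ∝ p_a(y|x,z) * P_b(y|z) / P_a(y|z).

    All arguments have shape (n, K): one row per sample, already indexed by
    that sample's z. prior_a must be strictly positive.
    """
    w = post_a * prior_b / prior_a
    return w / w.sum(axis=1, keepdims=True)


def em_conditional_prevalence(post_a, z, prior_a_table, n_iter=100, tol=1e-8):
    """Estimate P_b(Y|Z=z) for each discrete z from unlabeled target samples.

    post_a:        (n, K) posteriors p_a(y|x,z) from the source-trained model.
    z:             (n,) integer confounder codes in [0, n_z).
    prior_a_table: (n_z, K) source conditional prevalence P_a(Y|Z).
    Returns an (n_z, K) table; rows for z values absent from the target
    keep their source prevalence. Entries that reach 0 stay at 0 (standard EM).
    """
    post_a = np.asarray(post_a, dtype=float)
    z = np.asarray(z)
    prior_a_table = np.asarray(prior_a_table, dtype=float)
    prior_b_table = prior_a_table.copy()  # initialize at the source prevalence
    for _ in range(n_iter):
        # E-step: adjust every posterior to the current prevalence estimate.
        post_b = adjust_posteriors(post_a, prior_a_table[z], prior_b_table[z])
        # M-step: new prevalence = mean adjusted posterior within each z group.
        new_table = prior_b_table.copy()
        for k in np.unique(z):
            new_table[k] = post_b[z == k].mean(axis=0)
        converged = np.max(np.abs(new_table - prior_b_table)) < tol
        prior_b_table = new_table
        if converged:
            break
    return prior_b_table


# Example with a perfect source classifier (one-hot posteriors): EM recovers
# the empirical label frequencies per confounder group, so the z=0 row becomes
# [0.75, 0.25], the z=1 row becomes [0, 1], and z=2 (unseen) keeps [0.5, 0.5].
post = np.array([[1, 0], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1]], dtype=float)
z_codes = np.array([0, 0, 0, 0, 1, 1])
prior_a = np.full((3, 2), 0.5)  # hypothetical source prevalence table
estimate = em_conditional_prevalence(post, z_codes, prior_a)
```

Running EM per confounder value keeps each group's estimate independent, which is simple but needs enough unlabeled target samples in every group; a model that generalizes across groups, which the TL;DR's two-network design may provide, would avoid that limitation.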
