
Adapting to Shifting Correlations with Unlabeled Data Calibration

Minh Nguyen, Alan Q. Wang, Heejong Kim, Mert R. Sabuncu

TL;DR

Generalized Prevalence Adjustment (GPA) tackles distribution shifts across sites by adaptively using unstable confounder information without requiring labels or confounder data at test time. It extends CoPA with an EM-based, unlabeled-data approach to estimate the conditional prevalence $P_b(Y|\mathbf{Z})$ at new sites and a dropout-like input knockout to handle missing $\mathbf{Z}$, leveraging two networks $f_\theta$ and $g_\phi$ and a calibration step. Empirical results on Color MNIST and real-world medical imaging datasets (ISIC, CXR) show GPA often outperforms standard baselines and CoPA, including under scenarios where $\mathbf{Z}$ is partially missing. The work contributes a scalable, data-efficient method for improving OOD generalization in settings with shifting $Y|\mathbf{Z}$, with practical implications for healthcare imaging where site metadata can be incomplete or unstable.

Abstract

Distribution shifts between sites can seriously degrade model performance since models are prone to exploiting unstable correlations. Thus, many methods try to find features that are stable across sites and discard unstable features. However, unstable features might have complementary information that, if used appropriately, could increase accuracy. More recent methods try to adapt to unstable features at the new sites to achieve higher accuracy. However, they make unrealistic assumptions or fail to scale to multiple confounding features. We propose Generalized Prevalence Adjustment (GPA for short), a flexible method that adjusts model predictions to the shifting correlations between prediction target and confounders to safely exploit unstable features. GPA can infer the interaction between target and confounders in new sites using unlabeled samples from those sites. We evaluate GPA on several real and synthetic datasets, and show that it outperforms competitive baselines.
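To make the idea of calibrating on unlabeled data concrete, the sketch below shows the classic EM procedure for re-estimating class prevalence under label shift. It is not the authors' GPA algorithm: it is a simplified, assumed stand-in that estimates only the marginal prevalence $P_b(Y)$ at a new site rather than the conditional prevalence $P_b(Y|\mathbf{Z})$, and it assumes the source model's posteriors are well calibrated. The function name `em_prevalence` and its parameters are hypothetical.

```python
import numpy as np


def em_prevalence(probs, train_prior, n_iter=200, tol=1e-8):
    """Estimate class prevalence at a new site from unlabeled samples.

    Illustrative sketch only (assumed simplification, not the GPA method).

    probs:       (n, k) posteriors from a model trained at the source site,
                 evaluated on unlabeled samples from the new site.
    train_prior: (k,) class prevalence at the source site.
    Returns (estimated new-site prevalence, prevalence-adjusted posteriors).
    """
    probs = np.asarray(probs, dtype=float)
    train_prior = np.asarray(train_prior, dtype=float)
    prior = train_prior.copy()

    for _ in range(n_iter):
        # E-step: reweight source posteriors by the ratio of the current
        # prevalence estimate to the source prevalence, then renormalize.
        weighted = probs * (prior / train_prior)
        post = weighted / weighted.sum(axis=1, keepdims=True)
        # M-step: the new prevalence estimate is the mean adjusted posterior.
        new_prior = post.mean(axis=0)
        converged = np.abs(new_prior - prior).max() < tol
        prior = new_prior
        if converged:
            break

    # Final adjustment of the posteriors using the converged prevalence.
    weighted = probs * (prior / train_prior)
    post = weighted / weighted.sum(axis=1, keepdims=True)
    return prior, post
```

When $\mathbf{Z}$ is discrete and observed at the new site, one plausible extension (again an assumption, not taken from the paper) is to run the same loop separately within each stratum of $\mathbf{Z}$ to obtain a per-stratum prevalence estimate.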

Paper Structure

This paper contains 25 sections, 16 equations, 10 figures, 5 tables, and 3 algorithms.

Figures (10)

  • Figure 1: Left: The generative distribution of image $\mathbf{X}$ (i.e., $P(\mathbf{X}|Y,\mathbf{Z})$) is stable across sites, while the correlations between prediction target $Y$ and confounding variables $\mathbf{Z}$ can vary. Red edges are unstable (i.e., their generative mechanisms vary across sites), while black edges are stable. Right: The generative mechanism inspires our two-part modeling approach, where one model (i.e., $f(\mathbf{X},\mathbf{Z})$) learns the stable mechanism, while another model (i.e., $g(\mathbf{Z})$) learns the mechanism that varies across sites. At a new site, only $g(\mathbf{Z})$ needs to be estimated from unlabeled data, while $f(\mathbf{X},\mathbf{Z})$ can be reused.
  • Figure 2: F1-score at the test site. GPA outperforms all the baseline approaches.
  • Figure 3: Color MNIST experiment ablation. F1-scores are shown. $\cdot^*$ indicates that $\mathbf{Z}$ is missing at the test site for that method.
  • Figure 4: F1-score at the test sites in the ISIC experiment. GPA is on par with or better than the baseline methods.
  • Figure 5: ISIC experiment ablation. F1-scores are shown. $\cdot^*$ indicates that $\mathbf{Z}$ is missing at the test site for that method.
  • ...and 5 more figures