
Considerations for Distribution Shift Robustness of Diagnostic Models in Healthcare

Arno Blaas, Adam Goliński, Andrew Miller, Luca Zappella, Jörn-Henrik Jacobsen, Christina Heinze-Deml

TL;DR

We theoretically show why ignoring covariates, as well as common invariant learning approaches, will in general not yield robust predictors in the studied setting, while including certain covariates in the prediction model will.

Abstract

We consider robustness to distribution shifts in the context of diagnostic models in healthcare, where the prediction target $Y$, e.g., the presence of a disease, is causally upstream of the observations $X$, e.g., a biomarker. Distribution shifts may occur, for instance, when the training data is collected in a domain with patients having particular demographic characteristics while the model is deployed on patients from a different demographic group. In the domain of applied ML for health, it is common to predict $Y$ from $X$ without considering further information about the patient. However, beyond the direct influence of the disease $Y$ on biomarker $X$, a predictive model may learn to exploit confounding dependencies (or shortcuts) between $X$ and $Y$ that are unstable under certain distribution shifts. In this work, we highlight a data generating mechanism common to healthcare settings and discuss how recent theoretical results from the causality literature can be applied to build robust predictive models. We theoretically show why ignoring covariates as well as common invariant learning approaches will in general not yield robust predictors in the studied setting, while including certain covariates into the prediction model will. In an extensive simulation study, we showcase the robustness (or lack thereof) of different predictors under various data generating processes. Lastly, we analyze the performance of the different approaches using the PTB-XL dataset, a public dataset of annotated ECG recordings.

Paper Structure

This paper contains 22 sections, 1 theorem, 8 equations, 10 figures.

Key Result

Proposition 1

For any element $P_t \in \mathcal{P}_{cause}$ (the set of target distributions defined in Eq. (cause) of the paper), it holds that $P_t(Y|X,V) = P_s(Y|X,V)$, where $P_s$ is the source distribution. Furthermore, for such a $P_t$, in general $P_t(Y|X) \neq P_s(Y|X)$ and $P_t(Y) \neq P_s(Y)$.
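To make Proposition 1 concrete, the sketch below computes the relevant posteriors exactly for a toy instance of the causal relation process of Figure 1(b), with binary $Y$ and $V$ and a scalar Gaussian biomarker $X$ that depends on both $Y$ and $V$. All numbers are illustrative assumptions, not values from the paper, and we assume (our reading of $\mathcal{P}_{cause}$) that a shift changes only $P(V)$ while $P(Y|V)$ and $P(X|Y,V)$ stay fixed.

```python
import math

# Hypothetical mechanisms for a toy "causal relation process" (V -> Y -> X, V -> X).
# Under the assumed shift only P(V) changes; these two mechanisms are shared by
# source and target.
P_Y1_GIVEN_V = {0: 0.2, 1: 0.7}          # P(Y=1 | V=v)
MEAN_X = {(0, 0): 0.0, (0, 1): 1.0,      # mean of X given (Y, V) = (y, v);
          (1, 0): 2.0, (1, 1): 3.0}      # X | Y, V ~ N(mean, 1)


def normal_pdf(x, mean, std=1.0):
    """Density of N(mean, std**2) evaluated at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def joint(x, y, v, p_v1):
    """Density p(x, Y=y, V=v) = P(v) * P(y | v) * p(x | y, v), with P(V=1) = p_v1."""
    p_v = p_v1 if v == 1 else 1.0 - p_v1
    p_y = P_Y1_GIVEN_V[v] if y == 1 else 1.0 - P_Y1_GIVEN_V[v]
    return p_v * p_y * normal_pdf(x, MEAN_X[(y, v)])


def post_y_given_xv(x, v, p_v1):
    """P(Y=1 | X=x, V=v): normalise the joint over y with v held fixed."""
    num = joint(x, 1, v, p_v1)
    return num / (num + joint(x, 0, v, p_v1))


def post_y_given_x(x, p_v1):
    """P(Y=1 | X=x): sum V out of the joint, then normalise over y."""
    num = sum(joint(x, 1, v, p_v1) for v in (0, 1))
    den = num + sum(joint(x, 0, v, p_v1) for v in (0, 1))
    return num / den


def marginal_y(p_v1):
    """P(Y=1) = sum over v of P(v) * P(Y=1 | v)."""
    return (1.0 - p_v1) * P_Y1_GIVEN_V[0] + p_v1 * P_Y1_GIVEN_V[1]


if __name__ == "__main__":
    p_s, p_t = 0.4, 0.9   # illustrative source and target values of P(V=1)
    x = 1.5
    for v in (0, 1):
        print(f"P(Y=1|X={x},V={v}): source={post_y_given_xv(x, v, p_s):.4f}, "
              f"target={post_y_given_xv(x, v, p_t):.4f}")
    print(f"P(Y=1|X={x}):     source={post_y_given_x(x, p_s):.4f}, "
          f"target={post_y_given_x(x, p_t):.4f}")
    print(f"P(Y=1):           source={marginal_y(p_s):.4f}, target={marginal_y(p_t):.4f}")
```

The mechanism behind the proposition is visible in the code: $P(V=v)$ multiplies both terms in `post_y_given_xv`, so it cancels, whereas in `post_y_given_x` and `marginal_y` it weights the two values of $v$ differently and does not cancel.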

Figures (10)

  • Figure 1: Data generating processes considered in this work. $I_V$ is an intervention variable, which describes the assumed distribution shift. (a) The "spurious relation process" features a spurious relation between $Y$ and $V$, through a confounding variable $C$, and has been considered in the ML literature [HeinzeDeml21, Veitch21, makar22a, puli22]. This setting requires that the marginal $P(Y)$ remains invariant across distribution shifts. (b) In the "causal relation process", the shortcut variable $V$ is a direct cause of the outcome $Y$, shifting the marginal $P(Y)$ when $I_V$ shifts the marginal $P(V)$.
  • Figure 2: The four likelihood conditionals $P(X|Y,V)$, one for each combination of $Y, V$, for all three scenarios.
  • Figure 3: Scenario 1: AUC (left) and accuracy (right) of the models $P_s(Y | X)$ and $P_s(Y | X, V)$, the oracle predictors $P_t(Y | X)$ and $P_t(Y | X, V)$, and the invariant predictor $P_m(Y | X)$ of [makar22a], as a function of the target marginal $P_t(V) = p$, with source $P_s(V) = 0.4$ (grey dashed vertical line). Note that $P_s(Y | X, V)$ and $P_t(Y | X, V)$ overlap almost perfectly, up to the stochasticity of the estimation procedure. For $P_m(Y | X)$, we report the mean and standard deviation over 100 training runs for the hyperparameter setting that is best according to the area under the accuracy-vs-$P_t(V)$ curve. See the main text for details.
  • Figure 4: Scenario 2: AUC (left) and accuracy (right) of the models $P_s(Y | X)$ and $P_s(Y | X, V)$, the oracle predictors $P_t(Y | X)$ and $P_t(Y | X, V)$, and the invariant predictor $P_m(Y | X)$ of [makar22a], as a function of the target marginal $P_t(V) = p$, with source $P_s(V) = 0.1$ (grey dashed vertical line). For $P_m(Y | X)$, we report the mean and standard deviation over 100 training runs for the hyperparameter setting that is best according to the area under the accuracy-vs-$P_t(V)$ curve. See the main text for details.
  • Figure 5: Scenario 3: AUC (left) and accuracy (right) of the models $P_s(Y | X)$ and $P_s(Y | X, V)$, the oracle predictors $P_t(Y | X)$ and $P_t(Y | X, V)$, and the invariant predictor $P_m(Y | X)$ of [makar22a], as a function of the distribution shift under $\mathcal{P}_{spur}$: $P_t(V=1 | Y=1) = P_t(V=0 | Y=0) = p$, with source $P_s(V=1 | Y=1) = P_s(V=0 | Y=0) = 0.2$ (grey dashed vertical line). For $P_m(Y | X)$, we report the mean and standard deviation over 100 training runs for the hyperparameter setting that is best according to the area under the accuracy-vs-$P_t(V)$ curve. See the main text for details.
  • ...and 5 more figures
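For contrast, the shift family of Scenario 3 can be explored with the same kind of exact computation. The sketch below assumes a toy spurious relation process with binary $Y$ and $V$, a fixed $P(Y=1)=0.3$ and fixed Gaussian likelihoods $P(X|Y,V)$ (all hypothetical values), and varies only the coupling $P(V=1|Y=1)=P(V=0|Y=0)=p$ from the Figure 5 caption. The confounder $C$ of Figure 1(a) is not modelled explicitly; its effect is summarised by $P(V|Y)$.

```python
import math

# Hypothetical toy version of the "spurious relation process": Y and V are
# coupled (through a confounder, not modelled explicitly), and X depends on (Y, V).
P_Y1 = 0.3                                # P(Y=1); held fixed under the shift
MEAN_X = {(0, 0): 0.0, (0, 1): 1.0,       # X | Y, V ~ N(mean, 1); held fixed
          (1, 0): 2.0, (1, 1): 3.0}


def normal_pdf(x, mean, std=1.0):
    """Density of N(mean, std**2) evaluated at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def p_v_given_y(v, y, p):
    """Shift parametrisation from the Scenario 3 caption: P(V=1|Y=1) = P(V=0|Y=0) = p."""
    return p if v == y else 1.0 - p


def joint(x, y, v, p):
    """Density p(x, Y=y, V=v) = P(y) * P(v | y) * p(x | y, v)."""
    p_y = P_Y1 if y == 1 else 1.0 - P_Y1
    return p_y * p_v_given_y(v, y, p) * normal_pdf(x, MEAN_X[(y, v)])


def post_y_given_xv(x, v, p):
    """P(Y=1 | X=x, V=v) under coupling strength p."""
    num = joint(x, 1, v, p)
    return num / (num + joint(x, 0, v, p))


def marginal_y(p):
    """P(Y=1) = sum over v of P(Y=1) * P(v | Y=1); independent of p by construction."""
    return sum(P_Y1 * p_v_given_y(v, 1, p) for v in (0, 1))


def marginal_v(p):
    """P(V=1) = sum over y of P(y) * P(V=1 | y); changes with p unless P(Y=1) = 0.5."""
    return P_Y1 * p_v_given_y(1, 1, p) + (1.0 - P_Y1) * p_v_given_y(1, 0, p)


if __name__ == "__main__":
    p_s, p_t = 0.2, 0.8   # source coupling (Scenario 3 caption) and a shifted target value
    x = 1.5
    for v in (0, 1):
        print(f"P(Y=1|X={x},V={v}): source={post_y_given_xv(x, v, p_s):.4f}, "
              f"target={post_y_given_xv(x, v, p_t):.4f}")
    print(f"P(Y=1): source={marginal_y(p_s):.4f}, target={marginal_y(p_t):.4f}")
    print(f"P(V=1): source={marginal_v(p_s):.4f}, target={marginal_v(p_t):.4f}")
```

Since $P(Y|X,V) \propto P(X|Y,V)\,P(V|Y)\,P(Y)$ and the middle factor depends on $p$, the conditional that Proposition 1 shows to be stable under $\mathcal{P}_{cause}$ is not preserved in this toy, even though $P(Y)$ is.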

Theorems & Definitions (2)

  • Proposition 1
  • Remark 1