Mitigating Shortcut Learning via Feature Disentanglement in Medical Imaging: A Benchmark Study

Sarah Müller, Philipp Berens

TL;DR

This study tackles shortcut learning in medical imaging by benchmarking feature disentanglement methods that separate task-relevant from confounder information in latent spaces. It systematically compares data-centric rebalancing and model-centric latent-space disentanglement approaches, including distance correlation, MI estimation, and MMD, across a controlled toy dataset and two medical imaging datasets with strong confounding. Key findings show that both data-centric and model-centric strategies improve primary-task performance under distribution shifts, with the combination of rebalancing and disentanglement (especially using distance correlation) yielding robust gains and favorable computational efficiency. The work provides practical guidance on designing robust, generalizable medical imaging models and highlights how latent-space analyses reveal disentanglement quality beyond standard AUROC metrics.

Abstract

Although deep learning models in medical imaging often achieve excellent classification performance, they can rely on shortcut learning, exploiting spurious correlations or confounding factors that are not causally related to the target task. This poses risks in clinical settings, where models must generalize across institutions, populations, and acquisition conditions. Feature disentanglement is a promising approach to mitigate shortcut learning by separating task-relevant information from confounder-related features in latent representations. In this study, we systematically evaluated feature disentanglement methods for mitigating shortcuts in medical imaging, including adversarial learning and latent space splitting based on dependence minimization. We assessed classification performance and disentanglement quality using latent space analyses across one artificial and two medical datasets with natural and synthetic confounders. We also examined robustness under varying levels of confounding and compared computational efficiency across methods. We found that shortcut mitigation methods improved classification performance under strong spurious correlations during training. Latent space analyses revealed differences in representation quality not captured by classification metrics, highlighting the strengths and limitations of each method. Model reliance on shortcuts depended on the degree of confounding in the training data. The best-performing models combine data-centric rebalancing with model-centric disentanglement, achieving stronger and more robust shortcut mitigation than rebalancing alone while maintaining similar computational efficiency.

Paper Structure (38 sections, 11 equations, 6 figures, 4 tables)

Figures (6)

  • Figure 1: Overview of shortcut learning and mitigation via feature disentanglement. a Example of a spurious correlation between two binary tasks in the training data that reverses at test time, illustrating a distribution shift. b Causal graph in which a confounder affects both tasks $y_1$ and $y_2$, while the image $X$ is generated from both; the predictive association from $y_2$ to $y_1$ constitutes a shortcut. c Subspace disentanglement architecture that splits the latent representation into task-specific subspaces and minimizes their statistical dependence by minimizing a dependence measure $\mathcal{L}_\text{dep}$. d Comparison of latent representations showing that disentanglement reduces clustering of the confounding label $y_2$ in the target-task subspace $z_1$.
  • Figure 2: Overview of label distributions in Morpho-MNIST, CheXpert, and OCT. a shows example images sampled for each label combination, b shows contingency tables of the original training data, and c shows contingency tables of the sub-sampled training data actually used. In the final training data (c), strong correlations were induced between the primary task and confounder for all datasets, while maintaining balanced distributions within each classification task.
  • Figure 3: Test data distributions for evaluating shortcut mitigation, showing original (a), balanced (b), and inverted (c) correlations between the primary task and confounder.
  • Figure 4: Qualitative scatter plots showing the two-dimensional subspace $z_1$ of the first fold, highlighted by the binary labels $y_2$. Feature disentanglement is considered successful if no structure is visible with respect to $y_2$.
  • Figure 5: Relative AUROC improvement over the Baseline on the inverted test distribution ($p(y_1=1 \mid y_2=1)=5\%$) across different conditional prevalence levels used during training. Positive values indicate gains relative to the Baseline; variants combined with Rebalancing are shown with dashed lines.
  • ...and 1 more figures
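
The dependence measure $\mathcal{L}_\text{dep}$ in Figure 1c can be illustrated with distance correlation, one of the measures the benchmark compares. The following is a minimal NumPy sketch of the standard (biased) sample distance correlation between two latent subspaces, not the authors' implementation; the function names, the batch shapes, and the `eps` stabilizer are assumptions made for illustration.

```python
import numpy as np


def _centered_distances(z):
    """Pairwise Euclidean distance matrix of the rows of z, double-centered."""
    d = np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)
    row_mean = d.mean(axis=1, keepdims=True)
    col_mean = d.mean(axis=0, keepdims=True)
    return d - row_mean - col_mean + d.mean()


def distance_correlation(z1, z2, eps=1e-12):
    """Biased sample distance correlation between two batches of latents.

    z1 and z2 are arrays of shape (n_samples, dim_1) and (n_samples, dim_2).
    The result lies in [0, 1]; values near 0 indicate weak dependence.
    `eps` is a hypothetical stabilizer against division by zero.
    """
    a = _centered_distances(z1)
    b = _centered_distances(z2)
    dcov2 = max((a * b).mean(), 0.0)   # squared distance covariance
    dvar1 = (a * a).mean()             # squared distance variance of z1
    dvar2 = (b * b).mean()             # squared distance variance of z2
    return float(np.sqrt(dcov2 / (np.sqrt(dvar1 * dvar2) + eps)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    z1 = rng.normal(size=(256, 2))                      # target-task subspace
    z2_entangled = z1 + 0.1 * rng.normal(size=(256, 2)) # leaks z1 information
    z2_independent = rng.normal(size=(256, 2))          # unrelated subspace
    print("entangled:  ", distance_correlation(z1, z2_entangled))
    print("independent:", distance_correlation(z1, z2_independent))
```

In a training loop, a differentiable version of the same quantity (for example, written with an autodiff framework) would be added to the classification losses so that the subspaces $z_1$ and $z_2$ are pushed toward statistical independence. The weighting of that term is a design choice not specified in this summary.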