Preventing Shortcut Learning in Medical Image Analysis through Intermediate Layer Knowledge Distillation from Specialist Teachers

Christopher Boland, Sotirios Tsaftaris, Sonia Dahdouh

TL;DR

The paper tackles shortcut learning in medical image analysis by introducing intermediate-layer knowledge distillation from a specialist, bias-free teacher to a student trained on biased data. By applying sample-level KL divergence across multiple intermediate probes and an additional final-layer distillation term, the method guides learning toward clinically relevant features, improving generalization on both in-distribution and out-of-distribution data. Across CheXpert, ISIC 2017, and SimBA with CNN and 3D-CNN architectures, the approach yields consistent reductions in bias (Delta TPR) and competitive AUC, often approaching the performance of models trained on bias-free data. The work also demonstrates practical benefits, such as effectiveness with small or out-of-distribution teacher data, partial-layer distillation, and the use of compact teachers, offering a viable path for deploying robust medical AI under limited bias annotations.

Abstract

Deep learning models are prone to learning shortcut solutions to problems using spuriously correlated yet irrelevant features of their training data. In high-risk applications such as medical image analysis, this phenomenon may prevent models from using clinically meaningful features when making predictions, potentially leading to poor robustness and harm to patients. We demonstrate that different types of shortcuts (those that are diffuse and spread throughout the image, as well as those that are localized to specific areas) manifest distinctly across network layers and can, therefore, be more effectively targeted through mitigation strategies that target the intermediate layers. We propose a novel knowledge distillation framework that leverages a teacher network fine-tuned on a small subset of task-relevant data to mitigate shortcut learning in a student network trained on a large dataset corrupted with a bias feature. Through extensive experiments on CheXpert, ISIC 2017, and SimBA datasets using various architectures (ResNet-18, AlexNet, DenseNet-121, and 3D CNNs), we demonstrate consistent improvements over traditional Empirical Risk Minimization, augmentation-based bias-mitigation, and group-based bias-mitigation approaches. In many cases, we achieve comparable performance with a baseline model trained on bias-free data, even on out-of-distribution test data. Our results demonstrate the practical applicability of our approach to real-world medical imaging scenarios where bias annotations are limited and shortcut features are difficult to identify a priori.
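The distillation objective summarized in the TL;DR (a per-sample KL divergence at several intermediate probes plus a final-layer term) can be illustrated with a minimal sketch. This is not the authors' implementation: the function names, the averaging over probes, the `temperature` and `final_weight` parameters, and the KL direction (teacher relative to student) are all assumptions made for illustration, and the paper's own equations should be treated as authoritative.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

def distillation_loss(teacher_probe_logits, student_probe_logits,
                      teacher_final_logits, student_final_logits,
                      temperature=1.0, final_weight=1.0):
    """Per-sample loss: mean KL over intermediate probes plus a weighted
    final-layer KL term. The weighting scheme is an assumption."""
    probe_terms = [
        kl_divergence(softmax(t, temperature), softmax(s, temperature))
        for t, s in zip(teacher_probe_logits, student_probe_logits)
    ]
    intermediate = sum(probe_terms) / len(probe_terms)
    final = kl_divergence(softmax(teacher_final_logits, temperature),
                          softmax(student_final_logits, temperature))
    return intermediate + final_weight * final

# Hypothetical logits for one sample, two probes, binary task.
teacher_probes = [[2.0, -1.0], [3.0, -2.0]]
student_probes = [[0.5, 0.4], [1.0, 0.0]]
loss = distillation_loss(teacher_probes, student_probes, [4.0, -3.0], [1.5, 0.5])
```

In a real training loop this term would typically be added to the student's task loss and computed on probe outputs attached to matching layers of the teacher and student; how those probes are attached and trained is left out here.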

Paper Structure

This paper contains 38 sections, 3 equations, 15 figures, 9 tables.

Figures (15)

  • Figure 1: Overview of the proposed student-teacher training method. The teacher network, trained on clean data, guides the student model's learning process through the distillation of task-specific knowledge to the intermediate layers.
  • Figure 2: ISIC skin lesion image augmented with synthetic shortcuts: (a) original; (b) noise; (c) square (constant location); (d) square (random location). The noise effect has been amplified here for illustrative purposes.
  • Figure 3: Illustrative representation of the synthetic shortcut feature distribution in our train, validation, and test splits in the CheXpert and ISIC datasets.
  • Figure 4: Intermediate-layer confidence of two ResNet-18 models trained on near-identical training sets. Confidence bands represent the standard deviation over 5-fold cross-validation. Both networks are trained on the CheXpert dataset with a learning rate of $10^{-4}$. Intermediate layer classification probes have a learning rate of $0.1$. The training data of one model has been corrupted with various synthetic shortcut features, while the training data of the other has not.
  • Figure 5: Performance of ResNet-18 trained on ISIC and CheXpert datasets featuring multiple simultaneous shortcuts. The green line represents a model trained on a training set before augmenting with synthetic shortcuts. We compare our student with a specialized teacher to JTT and GroupDRO.
  • ...and 10 more figures
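
The three synthetic shortcuts named in the Figure 2 caption (noise, a square at a constant location, and a square at a random location) can be sketched as simple image transforms. This is an illustrative sketch only: the Gaussian noise model, the `sigma`, `size`, and `value` parameters, and the representation of an image as a list of rows with pixel values in [0, 1] are assumptions, not details taken from the paper.

```python
import random

def add_noise(image, sigma=0.05, rng=None):
    """Diffuse shortcut: add Gaussian noise to every pixel (assumed noise model)."""
    rng = rng or random.Random(0)
    return [[min(1.0, max(0.0, px + rng.gauss(0.0, sigma))) for px in row]
            for row in image]

def add_square(image, size=4, value=1.0, top_left=(0, 0)):
    """Localized shortcut: paint a square patch at a fixed location."""
    out = [row[:] for row in image]  # copy so the input image is untouched
    r0, c0 = top_left
    for r in range(r0, r0 + size):
        for c in range(c0, c0 + size):
            out[r][c] = value
    return out

def add_square_random(image, size=4, value=1.0, rng=None):
    """Localized shortcut: paint the same patch at a random valid location."""
    rng = rng or random.Random(0)
    height, width = len(image), len(image[0])
    top_left = (rng.randrange(height - size + 1), rng.randrange(width - size + 1))
    return add_square(image, size, value, top_left)

# Hypothetical 8x8 grayscale image with all pixels at 0.5.
image = [[0.5] * 8 for _ in range(8)]
noisy = add_noise(image)
fixed = add_square(image, size=2)
moved = add_square_random(image, size=2)
```

To turn any of these into a shortcut, the transform would be applied preferentially to one class in the training split so that it correlates with the label; the exact proportions used per split are what Figure 3 illustrates and are not reproduced here.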