Fast Diffusion-Based Counterfactuals for Shortcut Removal and Generation

Nina Weng, Paraskevas Pegios, Eike Petersen, Aasa Feragen, Siavash Bigdeli

TL;DR

This work addresses the problem of shortcut learning in medical imaging by introducing FastDiME, a diffusion-based counterfactual generation method that efficiently removes or adds targeted shortcut features. It combines two components. The first is an efficient gradient-estimation strategy that computes the guidance gradient on denoised predictions. The second is a self-optimized masking mechanism that localizes edits. Together they achieve approximately $20\times$ faster inference while maintaining counterfactual quality and realism. The authors also present a dedicated pipeline to detect and quantify shortcut learning. It trains classifiers on synthetic datasets with varying shortcut correlations and evaluates how their predictions shift on shortcut counterfactuals. Empirical results on CelebA and several medical imaging datasets show that FastDiME outperforms prior diffusion-based methods on many metrics. It also approaches the performance of adversarially guided methods such as ACE, while offering substantial gains in speed and memory usage. The framework enables robust evaluation and potential mitigation of shortcut reliance in classifiers, and public code is provided to facilitate adoption in practical settings.
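The gradient-estimation idea can be illustrated with a minimal NumPy sketch. It assumes the standard DDPM parameterization. Under that parameterization, a noise prediction $\hat{\epsilon}$ gives a one-step denoised estimate $\bar{x}_0 = (x_t - \sqrt{1-\bar{\alpha}_t}\,\hat{\epsilon})/\sqrt{\bar{\alpha}_t}$. The sketch evaluates a counterfactual loss on that estimate instead of on a separately sampled clean image. Several pieces are hypothetical stand-ins, not the authors' implementation: the linear classifier, the L1 term and its weight `lam_l1`, the step `scale`, and all function names.

```python
import numpy as np


def denoised_estimate(x_t, eps_hat, alpha_bar_t):
    """One-step estimate of the clean image from the noisy sample x_t, using the
    DDPM identity x_t = sqrt(a) * x0 + sqrt(1 - a) * eps."""
    return (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_hat) / np.sqrt(alpha_bar_t)


def counterfactual_loss_grad(x0_hat, x_orig, w, b, target, lam_l1=0.1):
    """Gradient, w.r.t. the denoised estimate, of a hypothetical counterfactual loss:
    binary cross-entropy of a linear-logistic classifier sigmoid(w.x + b) towards
    `target`, plus lam_l1 * ||x0_hat - x_orig||_1 to stay close to the original."""
    p = 1.0 / (1.0 + np.exp(-(np.dot(w, x0_hat) + b)))
    grad_ce = (p - target) * w                   # dBCE/dx for a linear-logistic model
    grad_l1 = lam_l1 * np.sign(x0_hat - x_orig)  # subgradient of the L1 term
    return grad_ce + grad_l1


def guided_update(x_t, eps_hat, alpha_bar_t, x_orig, w, b, target, scale=1.0):
    """Shift x_t against the loss gradient computed at the denoised estimate.
    Simplifying assumption: the gradient is applied directly to x_t, so nothing
    is backpropagated through a denoising network."""
    x0_hat = denoised_estimate(x_t, eps_hat, alpha_bar_t)
    return x_t - scale * counterfactual_loss_grad(x0_hat, x_orig, w, b, target)


# Toy usage: a 2D "image" pushed towards class 1 of a linear classifier.
x_orig = np.zeros(2)
w, b = np.array([1.0, 0.0]), 0.0
x_t = guided_update(x_orig.copy(), np.zeros(2), 0.5, x_orig, w, b, target=1.0)
```

Because the loss is evaluated on a single denoised estimate, each step in this sketch needs only one classifier gradient. A realistic version would swap in a trained classifier and noise-prediction network.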

Abstract

Shortcut learning is when a model (e.g. a cardiac disease classifier) exploits correlations between the target label and a spurious shortcut feature (e.g. a pacemaker) to predict the target label based on the shortcut rather than real discriminative features. This is common in medical imaging, where treatment and clinical annotations correlate with disease labels, making them easy shortcuts to predict disease. We propose a novel method to detect and quantify the impact of potential shortcut features via fast diffusion-based counterfactual image generation that can synthetically remove or add shortcuts. Via a novel inpainting-based modification, we spatially limit the changes made without any extra inference steps. This encourages the removal of spatially constrained shortcut features while ensuring that the shortcut-free counterfactuals preserve their remaining image features to a high degree. Using these counterfactuals, we assess how shortcut features influence model predictions. This is enabled by our second contribution: an efficient diffusion-based counterfactual explanation method with significant inference speed-up at image quality comparable to the state of the art. We confirm this on two large chest X-ray datasets, a skin lesion dataset, and CelebA. Our code is publicly available at fastdime.compute.dtu.dk.
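The inpainting-based restriction can be sketched as follows. The sketch makes two assumptions. First, the mask is obtained by thresholding the normalized per-pixel difference between the current denoised estimate and the original image. Second, pixels outside the mask are replaced by a freshly noised copy of the original at the same time step, as in inpainting-style diffusion sampling. The threshold value, the min-max normalization, and the function names are illustrative assumptions, not the paper's exact settings.

```python
import numpy as np


def self_optimized_mask(x0_hat, x_orig, threshold=0.15):
    """Binary (H, W, 1) mask of pixels the counterfactual may change.
    Assumption: threshold the min-max normalized, channel-averaged |x0_hat - x_orig|."""
    diff = np.abs(x0_hat - x_orig).mean(axis=-1)
    span = diff.max() - diff.min()
    diff = (diff - diff.min()) / span if span > 0 else np.zeros_like(diff)
    return (diff > threshold).astype(x_orig.dtype)[..., None]


def noise_to_step(x_orig, alpha_bar_t, rng):
    """Sample q(x_t | x_0) for the original image: sqrt(a) * x0 + sqrt(1 - a) * eps."""
    eps = rng.standard_normal(x_orig.shape)
    return np.sqrt(alpha_bar_t) * x_orig + np.sqrt(1.0 - alpha_bar_t) * eps


def masked_blend(x_t_cf, x_orig, mask, alpha_bar_t, rng):
    """Keep the counterfactual sample inside the mask and the noised original outside,
    so regions irrelevant to the edit stay anchored to the input image."""
    return mask * x_t_cf + (1.0 - mask) * noise_to_step(x_orig, alpha_bar_t, rng)


# Toy usage: a 4x4 RGB image whose denoised estimate differs only in a 2x2 patch.
rng = np.random.default_rng(0)
x_orig = np.zeros((4, 4, 3))
x0_hat = x_orig.copy()
x0_hat[1:3, 1:3] = 1.0
mask = self_optimized_mask(x0_hat, x_orig)
x_next = masked_blend(np.ones_like(x_orig), x_orig, mask, alpha_bar_t=1.0, rng=rng)
```

The toy call uses `alpha_bar_t=1.0` (no noise) so the result is easy to inspect. Only the 2x2 patch takes values from the counterfactual sample, and everything else reverts to the original image.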

Paper Structure (43 sections, 8 equations, 16 figures, 8 tables)

Figures (16)

  • Figure 1: Shortcut detection. SmoothGrad and CF explanation (Left), two XAI methods, indicate which regions in the image could influence the model decision (e.g. from disease to non-disease). Although these regions include the shortcut feature, neither method singles it out, so an expert is still required for further visual inspection. Our counterfactual approach (Right) removes only the desired shortcut attribute, which lets us validate whether that specific attribute played a role in the model decision. SmoothGrad visualization: areas crucial to the model decision are highlighted. CF explanation: difference map between the original image and the CF (blue/red: information removal/addition).
  • Figure 2: Proposed FastDiME method. At each step, the noised image $x_t^c$ is sampled under the guidance of the counterfactual loss, leveraging information derived from the denoised image $\bar{x}_t^c$. A self-optimized mask is automatically extracted and applied at each time step to prevent changes in regions less relevant to the task.
  • Figure 3: Toy 2D example visualizing counterfactual generation convergence for DiME, GMD, and FastDiME. Given the two class distributions A and B, each method tries to move the initial point $x_t$ (red point) to class A. Blue vectors indicate the gradients of the classification loss $\nabla L_c$ at each step. When calculating this gradient at each step, DiME uses a new unconditional sample from the DDPM, which can lead to noisy gradients. In contrast, FastDiME uses the expected image $\bar{x}_0$ at each step to calculate the gradients, which results in more stable convergence. The plot on the right shows the convergence of each method in terms of distance to A. Averaged over 100 runs, it indicates that FastDiME is accurate and converges significantly faster than both other methods.
  • Figure 4: Validating the shortcut learning detection pipeline. To test whether shortcut counterfactuals can detect shortcut learning, we construct three synthetic training datasets with varying degrees of correlation between the shortcut feature $s$ and the target label $y$, and train a classifier on each (Left and Middle). The degree of shortcut learning is then examined by measuring the difference in classifier confidence between the original image and its shortcut counterfactual (Right). If model predictions differ strongly between natural images and shortcut counterfactual images (while the target label is unchanged), shortcut learning has occurred.
  • Figure 5: Shortcut counterfactuals for medical datasets. The shortcuts are highlighted with orange circles or boxes in the original images.
  • Figures 6 to 16: captions not included here.
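The validation pipeline of Figure 4 reduces to two small pieces. The first builds training splits in which a binary shortcut agrees with the label at a controlled rate. The second compares classifier confidence on original images with confidence on their shortcut counterfactuals. The sketch below assumes the shift is summarized as the mean absolute change in predicted probability. It also replaces real images, counterfactuals, and classifiers with hypothetical stand-ins acting on (y, s) pairs, so it is not the paper's exact metric or code.

```python
import numpy as np


def make_shortcut_split(n, agreement, rng):
    """Hypothetical split with binary label y and binary shortcut flag s.
    `agreement` = P(s == y): 0.5 means no correlation, 1.0 a perfect shortcut."""
    y = rng.integers(0, 2, size=n)
    agree = rng.random(n) < agreement
    s = np.where(agree, y, 1 - y)
    return y, s


def confidence_shift(p_orig, p_cf):
    """Mean absolute change in predicted probability between original images and
    their shortcut counterfactuals (target label unchanged)."""
    return float(np.mean(np.abs(np.asarray(p_orig, float) - np.asarray(p_cf, float))))


# Stand-in "classifiers" with a shared (y, s) signature:
# one relies only on the shortcut, the other only on the true label.
def shortcut_model(y, s):
    return np.where(s == 1, 0.9, 0.1)


def label_model(y, s):
    return np.where(y == 1, 0.9, 0.1)


rng = np.random.default_rng(0)
y, s = make_shortcut_split(1000, agreement=0.9, rng=rng)
s_cf = np.zeros_like(s)  # shortcut counterfactual: shortcut removed, label kept
shift_shortcut = confidence_shift(shortcut_model(y, s), shortcut_model(y, s_cf))
shift_label = confidence_shift(label_model(y, s), label_model(y, s_cf))
```

A model that ignores the shortcut shows no shift when the shortcut is removed. A model that depends on it changes its confidence on every image that carried the shortcut, which is the signal the pipeline looks for.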