Explaining 3D Computed Tomography Classifiers with Counterfactuals

Joseph Paul Cohen, Louis Blankemeier, Akshay Chaudhari

TL;DR

This work tackles the difficulty of explaining 3D CT classifiers with counterfactuals (CFs) by extending the Latent Shift approach to volumetric data. It introduces a slice-based autoencoder (Slice AE) that enables gradient-based CF generation on high-resolution CT volumes at a practical memory cost. A 2D encoder processes individual slices, which are then concatenated to preserve 3D context. A VQ-GAN trained on over 1.4 million CT slices enables plausible latent-space manipulations. The method is demonstrated on clinical phenotype prediction and lung segmentation, with qualitative and quantitative validation on multiple public datasets. The approach yields localized, clinically meaningful counterfactuals, improves the interpretability of high-stakes medical AI, and is publicly released to support future research and auditing.
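For reference, the original 2D Latent Shift method encodes an image with an autoencoder and moves its latent code along the classifier's gradient. The first line below restates that formulation:

- $E$ and $D$ are the encoder and decoder.
- $f$ is the classifier output.
- $\lambda$ sets the size and direction of the shift (negative values push the prediction down).

The second part is a hedged sketch of how this might carry over to the Slice AE. It assumes slices $x_1, \dots, x_S$ are encoded independently and only slices in a chosen chunk $C$ receive gradients. This is our reading of the description above, not necessarily the paper's exact equation.

```latex
% 2D Latent Shift
z = E(x), \qquad x_\lambda = D\!\left(z + \lambda \, \frac{\partial f(D(z))}{\partial z}\right)

% Hedged slice-wise sketch: encode each slice, stack the decoded slices into a volume,
% and shift only the latents of slices in the chunk C
z_i = E(x_i), \qquad \hat{x}(z_1, \dots, z_S) = \big[ D(z_1), \dots, D(z_S) \big]

z_i' =
\begin{cases}
  z_i + \lambda \, \dfrac{\partial f(\hat{x})}{\partial z_i}, & i \in C \\[4pt]
  z_i, & i \notin C
\end{cases}
\qquad
x_\lambda = \hat{x}(z_1', \dots, z_S')
```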

Abstract

Counterfactual explanations enhance the interpretability of deep learning models in medical imaging, yet adapting them to 3D CT scans is challenging because of volumetric complexity and resource demands. We extend the Latent Shift counterfactual generation method from 2D applications to explain 3D computed tomography (CT) scan classifiers. We address the challenges specific to 3D classifiers, such as limited training samples and high memory demands, by implementing a slice-based autoencoder and blocking gradients for all but selected chunks of slices. This method uses a 2D encoder trained on CT slices; the encoded slices are then combined to maintain 3D context. We demonstrate this technique on two models, one for clinical phenotype prediction and one for lung segmentation. Our approach is both memory-efficient and effective for generating interpretable counterfactuals in high-resolution 3D medical imaging.
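The chunk-wise gradient blocking described above can be illustrated with a small, dependency-free Python sketch. Everything in it is hypothetical:

- A linear per-slice decoder, `decode`, stands in for the VQ-GAN decoder.
- A logistic function, `classify`, stands in for the 3D classifier.
- The gradient is written out analytically rather than coming from automatic differentiation.

Only the latents of slices in `chunk` are shifted. All other slice latents are held fixed, which is the toy analogue of blocking their gradients.

```python
import math
from typing import Iterable, List

Vector = List[float]
Matrix = List[List[float]]


def decode(A: Matrix, z: Vector) -> Vector:
    """Hypothetical per-slice linear decoder D(z) = A z (stand-in for the VQ-GAN decoder)."""
    return [sum(a * zc for a, zc in zip(row, z)) for row in A]


def classify(volume: List[Vector], w: List[Vector], b: float) -> float:
    """Hypothetical logistic classifier f over all decoded slices of the volume."""
    logit = b + sum(wp * xp for ws, xs in zip(w, volume) for wp, xp in zip(ws, xs))
    return 1.0 / (1.0 + math.exp(-logit))


def latent_shift_chunk(zs: List[Vector], A: Matrix, w: List[Vector], b: float,
                       lam: float, chunk: Iterable[int]):
    """Shift only the latents of slices in `chunk`; all other slice latents stay
    constant, mimicking gradient blocking. Returns (CF volume, CF prediction)."""
    active = set(chunk)
    p = classify([decode(A, z) for z in zs], w, b)
    dsig = p * (1.0 - p)  # derivative of the sigmoid with respect to its logit
    shifted = []
    for i, z in enumerate(zs):
        if i in active:
            # Analytic gradient for this toy model: df/dz_i = f(1 - f) * A^T w_i
            grad = [dsig * sum(A[r][c] * w[i][r] for r in range(len(A)))
                    for c in range(len(z))]
            shifted.append([zc + lam * g for zc, g in zip(z, grad)])
        else:
            shifted.append(list(z))  # outside the chunk: no gradient, no change
    cf_volume = [decode(A, z) for z in shifted]
    return cf_volume, classify(cf_volume, w, b)


# Toy usage: 6 slices, 2 latent dims per slice, 3 "pixels" per decoded slice.
A = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
zs = [[0.5, 0.2] for _ in range(6)]
w = [[0.3, 0.1, 0.2] for _ in range(6)]
p_input = classify([decode(A, z) for z in zs], w, 0.0)
cf_volume, p_cf = latent_shift_chunk(zs, A, w, 0.0, lam=-5.0, chunk=range(2, 4))
print(f"input p={p_input:.3f}, CF p={p_cf:.3f}")
```

In an autograd framework, the memory saving would come from needing a computation graph only for the chunk's latents. The remaining slices can be decoded without tracking gradients.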

Paper Structure

This paper contains 13 sections, 1 equation, 5 figures, 1 table.

Figures (5)

  • Figure 1: Illustration of the slice-based autoencoder (Slice AE) approach for generating counterfactuals in 3D CT volumes. Only selected latent representations, $z_1$ and $z_2$, have gradients computed from the classifier output, enabling memory-efficient counterfactual generation while preserving the ability to navigate and modify specific parts of the 3D volume.
  • Figure 2: A) An example CF generated for lung size. The model's predicted lung segmentation mask is shown below. A red outline tracing the segmentation of the input image is also overlaid on the segmentation of the CF image, confirming the reduction in predicted lung size. B) A plot of the number of pixels predicted as lung as $\lambda$ is varied during the Latent Shift CF generation process.
  • Figure 3: A: The input and CF slices showing a reduction in pleural effusion. The blue dashed line outlines the side of the lung area that is reduced in the CF. B: Localization of the CF slices that contribute to changes in the prediction. The volume is processed in chunks of five slices, restricting changes to each chunk. Slices 30-35 produce the most change, while slices 45-50 (shown in A) and 55-60 also reduce the prediction. Heatmaps of the differences between the input and CF are shown.
  • Figure 4: A comparison with other explanation methods that use gradient influence to highlight relevant pixels. Slice 45 of LUNG1-001 from the NSCLC-Radiomics dataset is visualized.
  • Figure 5: A) Results of varying the chunk size when computing CFs on the examples. Error bars show standard error. B) The change in the distribution of model predictions between the positive and negative examples and the CFs of the positive examples.
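The per-chunk localization in Figure 3B and the chunk-size sweep in Figure 5A both reduce to two steps: generate one CF per chunk, then rank chunks by how much each lowers the prediction.

The sketch below assumes a hypothetical callable, `cf_pred_for_chunk`, that returns the classifier's output on a CF restricted to a given chunk. The toy stand-in and its relevance values are invented for illustration and do not reproduce the paper's results.

```python
from typing import Callable, List, Tuple


def chunk_ranges(num_slices: int, chunk_size: int) -> List[range]:
    """Split slice indices 0..num_slices-1 into consecutive chunks; the last may be shorter."""
    return [range(s, min(s + chunk_size, num_slices))
            for s in range(0, num_slices, chunk_size)]


def localize_chunks(num_slices: int, chunk_size: int, base_pred: float,
                    cf_pred_for_chunk: Callable[[range], float]) -> List[Tuple[range, float]]:
    """Compute one CF prediction per chunk and rank chunks by how much they reduce
    the classifier's prediction, largest reduction first."""
    drops = [(c, base_pred - cf_pred_for_chunk(c))
             for c in chunk_ranges(num_slices, chunk_size)]
    return sorted(drops, key=lambda item: item[1], reverse=True)


# Hypothetical stand-in: in this toy 20-slice volume only slices 10-14 carry the
# finding, so a CF restricted to a chunk lowers the prediction by 0.1 per such slice.
relevance = [0.1 if 10 <= i < 15 else 0.0 for i in range(20)]
base = 0.9


def toy_cf_pred(chunk: range) -> float:
    return base - sum(relevance[i] for i in chunk)


ranking = localize_chunks(20, 5, base, toy_cf_pred)
for chunk, drop in ranking:
    print(f"slices {chunk.start}-{chunk.stop - 1}: drop {drop:.2f}")
```

Changing `chunk_size` (five slices in Figure 3) trades localization granularity against cost. Smaller chunks localize changes more finely but require more CFs per volume.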