Unmasking Dementia Detection by Masking Input Gradients: A JSM Approach to Model Interpretability and Precision

Yasmine Mustafa, Tie Luo

TL;DR

This work tackles trustworthy medical AI for dementia diagnosis by integrating Jacobian Saliency Maps (JSM) as a self-debugging signal during training through a Jacobian-Augmented Loss (JAL). The approach uses a modality-agnostic, deformation-based saliency map derived from image registration to emphasize disease-relevant changes in brain volume, mitigating Clever Hans behavior. Evaluated on the multimodal OASIS-3 dataset with MRI and CT scans, the method employs both early and late fusion in a lightweight 3D CNN and achieves substantial improvements in four-class dementia staging. Ablations confirm JAL's contribution, and visualizations show that gradients align with JSM-indicated deformations. The results indicate improved accuracy (up to ~10%), enhanced interpretability, and the potential for more reliable, clinically usable AI in neurodegenerative disease diagnosis.
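The TL;DR describes the Jacobian-Augmented Loss only at a high level. Below is a minimal sketch of one plausible form, assuming a "right for the right reasons"-style penalty: cross-entropy plus lam times the squared input gradient on voxels that a binary JSM mask M does not flag, i.e. L = CE + lam * sum_i ((1 - M_i) * dCE/dx_i)^2. The toy linear-softmax classifier (chosen so the input gradient has a closed form), the mask convention, and the names `jal_loss`, `lam`, and `jsm_mask` are illustrative assumptions, not the authors' formulation.

```python
import numpy as np


def softmax(z):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(z - z.max())
    return e / e.sum()


def jal_loss(W, b, x, y, jsm_mask, lam=1.0):
    """Hypothetical JAL-style loss (assumption, not the paper's exact form).

    W (K, N) and b (K,) define a toy linear classifier over a flattened
    volume x (N,); y is the integer class label (e.g. one of 4 stages);
    jsm_mask (N,) is 1 where the JSM flags deformation and 0 elsewhere.
    Returns (loss, input_gradient_of_cross_entropy).
    """
    p = softmax(W @ x + b)
    ce = -np.log(p[y])
    onehot = np.zeros_like(p)
    onehot[y] = 1.0
    # Closed-form dCE/dx for a linear-softmax model: W^T (p - onehot).
    grad_x = W.T @ (p - onehot)
    # Penalize gradient energy on voxels the JSM does NOT mark as deformed.
    penalty = np.sum(((1.0 - jsm_mask) * grad_x) ** 2)
    return ce + lam * penalty, grad_x


# Usage: four classes, eight toy "voxels", mask flags the outer voxels.
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 8))
b = np.zeros(4)
x = rng.normal(size=8)
mask = np.array([1, 1, 0, 0, 0, 0, 1, 1], dtype=float)
loss, grad = jal_loss(W, b, x, y=2, jsm_mask=mask, lam=0.1)
```

With a 3D CNN, the input gradient would come from automatic differentiation (with a graph kept for the second backward pass) rather than the closed form W^T(p - onehot); only the masking and squaring step carries over unchanged.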

Abstract

The evolution of deep learning and artificial intelligence has significantly reshaped technological landscapes. However, their effective application in critical sectors such as medicine demands not only superior performance but also trustworthiness. While interpretability plays a pivotal role, existing explainable AI (XAI) approaches often fail to reveal Clever Hans behavior, in which a model makes correct but ungeneralizable predictions by exploiting spurious correlations or biases in the data. Moreover, current post-hoc XAI methods are susceptible to generating unjustified counterfactual examples. In this paper, we approach XAI with an innovative model debugging methodology realized through a Jacobian Saliency Map (JSM). To ground the problem in a concrete context, we use Alzheimer's disease (AD) diagnosis as the use case, motivated by its significant impact on human lives and the formidable challenge of detecting it early, which stems from the intricate nature of its progression. We introduce an interpretable, multimodal model for classifying AD across its multiple progression stages, incorporating JSM as a modality-agnostic tool that provides insight into volumetric changes indicative of brain abnormalities. Our extensive evaluation, including an ablation study, demonstrates the efficacy of JSM for model debugging and interpretation, while also significantly improving model accuracy.
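The abstract presents the JSM as a modality-agnostic indicator of volumetric change. A standard way to obtain such a map, and the assumption behind this sketch, is the voxel-wise Jacobian determinant of the registration deformation phi(x) = x + u(x), namely J(x) = det(I + grad u(x)): values below 1 indicate local contraction (for example, atrophy) and values above 1 indicate expansion. Because J depends only on the geometry of the deformation, not on image intensities, it applies equally to MRI and CT. The displacement-field layout `(3, D, H, W)`, the `spacing` argument, and the function name `jacobian_determinant` are illustrative assumptions.

```python
import numpy as np


def jacobian_determinant(u, spacing=(1.0, 1.0, 1.0)):
    """Voxel-wise det(I + grad u) for a displacement field u of shape (3, D, H, W).

    Assumes u[i] is the displacement along axis i in the same units as `spacing`.
    """
    # grads[i][j] approximates d u_i / d x_j (central differences in the interior).
    grads = [np.gradient(u[i], *spacing) for i in range(3)]
    jac = np.empty(u.shape[1:] + (3, 3))
    for i in range(3):
        for j in range(3):
            # Add the identity so the matrix describes phi, not just u.
            jac[..., i, j] = grads[i][j] + (1.0 if i == j else 0.0)
    # np.linalg.det broadcasts over the leading (D, H, W) axes.
    return np.linalg.det(jac)


# Usage: a uniform 10% shrink along every axis gives J = 0.9**3 everywhere.
D, H, W = 5, 6, 7
grid = np.stack(np.meshgrid(np.arange(D), np.arange(H), np.arange(W),
                            indexing="ij")).astype(float)
jsm = jacobian_determinant(-0.1 * grid)
```

Thresholding the continuous map, for example `np.abs(jsm - 1) > tau` with a hypothetical tolerance `tau`, is one simple way to turn it into a binary mask of deformed regions.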

Paper Structure (11 sections, 9 equations, 4 figures, 3 tables)

Figures (4)

  • Figure 1: Preprocessing pipelines for MRI and CT scans: bias field correction for MRI, contrast stretching for CT to enhance diagnostic value, BET (Brain Extraction Tool) for brain extraction, and registration of both CT and MRI to the MNI152 brain template.
  • Figure 2: The complete pipeline. Model debugging using JSM is integrated into training and takes effect during backpropagation, either for each modality separately (late fusion) or for both jointly (early fusion). Final predictions are interpreted by plotting elevated gradients overlaid on the input images.
  • Figure 3: Ablation study on JAL in terms of performance histograms. (a-c): Early fusion, (d-f): Late fusion.
  • Figure 4: Visualization of larger gradients in JSM-indicated deformation areas for MRI and CT modalities.
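Several steps named in these captions are simple enough to sketch: percentile-based contrast stretching of CT intensities (Figure 1), early versus late fusion of the two modalities (Figure 2), and checking how much of the largest gradient magnitude falls inside JSM-flagged regions (Figure 4). All of this is illustrative under stated assumptions: the percentile bounds, the weighted-average late-fusion rule, the top-fraction threshold, and the names `contrast_stretch`, `early_fuse`, `late_fuse`, and `gradient_jsm_overlap` are hypothetical, not taken from the paper.

```python
import numpy as np


def contrast_stretch(vol, lo_pct=1.0, hi_pct=99.0):
    """Linearly rescale intensities between two percentiles to [0, 1], clipping outliers.

    The 1st/99th percentile bounds are an assumption, not the paper's setting.
    """
    lo, hi = np.percentile(vol, [lo_pct, hi_pct])
    return np.clip((vol - lo) / (hi - lo + 1e-8), 0.0, 1.0)


def early_fuse(mri, ct):
    """Early fusion: stack co-registered volumes as input channels, shape (2, D, H, W)."""
    return np.stack([mri, ct], axis=0)


def late_fuse(p_mri, p_ct, w_mri=0.5):
    """Late fusion (assumed rule): weighted average of per-modality class probabilities."""
    return w_mri * p_mri + (1.0 - w_mri) * p_ct


def gradient_jsm_overlap(grad, jsm_mask, top_frac=0.05):
    """Fraction of the top `top_frac` voxels by |gradient| that lie inside the JSM mask."""
    mag = np.abs(grad).ravel()
    k = max(1, int(round(top_frac * mag.size)))
    top = np.argpartition(mag, -k)[-k:]  # indices of the k largest magnitudes
    return float(jsm_mask.ravel()[top].astype(bool).mean())


# Usage: gradients concentrated in the first row, which the mask also covers.
g = np.zeros((4, 5))
g[0, :] = 10.0
m = np.zeros((4, 5))
m[0, :] = 1
overlap = gradient_jsm_overlap(g, m, top_frac=0.25)  # 5 of 20 voxels selected
```

An overlap near 1 would indicate that the most influential voxels coincide with JSM-indicated deformations, which is the qualitative pattern Figure 4 visualizes.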