Unmasking Dementia Detection by Masking Input Gradients: A JSM Approach to Model Interpretability and Precision
Yasmine Mustafa, Tie Luo
TL;DR
This work tackles trustworthy medical AI for dementia diagnosis by integrating Jacobian Saliency Maps (JSM) as a self-debugging signal during training through a Jacobian-Augmented Loss (JAL). The approach uses a modality-agnostic deformation-based saliency derived from image registration to emphasize disease-relevant brain volume changes, mitigating Clever Hans behavior. Evaluated on the multimodal OASIS-3 dataset with MRI and CT, the method employs early and late fusion in a lightweight 3D CNN and demonstrates substantial improvements in four-class dementia staging, with ablations confirming JAL’s contribution and visual evidence that gradients align with JSM-indicated deformations. The results indicate improved accuracy (up to ~10%), enhanced interpretability, and the potential for more reliable, clinically usable AI in neurodegenerative disease diagnosis.
Abstract
The evolution of deep learning and artificial intelligence has significantly reshaped technological landscapes. However, their effective application in critical sectors such as medicine demands not only superior performance but also trustworthiness. While interpretability plays a pivotal role, existing explainable AI (XAI) approaches often fail to reveal Clever Hans behavior, where a model makes correct but ungeneralizable predictions by exploiting spurious correlations or biases in the data. Likewise, current post-hoc XAI methods are susceptible to generating unjustified counterfactual examples. In this paper, we approach XAI with an innovative model debugging methodology realized through the Jacobian Saliency Map (JSM). To ground the problem in a concrete context, we use Alzheimer's disease (AD) diagnosis as the use case, motivated by its significant impact on human lives and the formidable challenge of its early detection, which stems from the intricate nature of its progression. We introduce an interpretable, multimodal model for AD classification over its multi-stage progression, incorporating JSM as a modality-agnostic tool that provides insights into volumetric changes indicative of brain abnormalities. Our extensive evaluation, including an ablation study, demonstrates the efficacy of using JSM for model debugging and interpretation, while also significantly improving model accuracy.
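As a rough illustration of how a deformation field from image registration can expose volumetric change, the sketch below computes a voxel-wise Jacobian determinant map with NumPy. It is a minimal sketch of the standard deformation-based morphometry calculation, not the authors' implementation: the abstract gives no formula, and the function name `jacobian_determinant_map`, the `(3, D, H, W)` displacement layout, and the `spacing` parameter are all assumptions made for this example. For a transformation phi(x) = x + u(x), a determinant above 1 indicates local expansion and a determinant below 1 indicates local contraction.

```python
import numpy as np


def jacobian_determinant_map(displacement, spacing=(1.0, 1.0, 1.0)):
    """Voxel-wise Jacobian determinant of phi(x) = x + u(x).

    displacement: array of shape (3, D, H, W) (assumed layout), where
    displacement[i] is the i-th component of u, in the units of `spacing`.
    Returns an array of shape (D, H, W).
    """
    u = np.asarray(displacement, dtype=float)
    if u.ndim != 4 or u.shape[0] != 3:
        raise ValueError("expected displacement of shape (3, D, H, W)")

    # jac[..., i, j] holds d u_i / d x_j, estimated by finite differences.
    jac = np.empty(u.shape[1:] + (3, 3))
    for i in range(3):
        grads = np.gradient(u[i], *spacing)
        for j in range(3):
            jac[..., i, j] = grads[j]

    # d phi / d x = I + d u / d x
    jac += np.eye(3)
    return np.linalg.det(jac)


# Toy check: a uniform 10% stretch along every axis, u(x) = 0.1 * x.
grid = np.stack(
    np.meshgrid(np.arange(4.0), np.arange(5.0), np.arange(6.0), indexing="ij")
)
det = jacobian_determinant_map(0.1 * grid)
```

In the toy check every voxel has the same determinant, 1.1 cubed, because the stretch is uniform. A displacement field from a real registration would instead give a spatially varying map in which regions of contraction or expansion stand out.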
