
AXIAL: Attention-based eXplainability for Interpretable Alzheimer's Localized Diagnosis using 2D CNNs on 3D MRI brain scans

Gabriele Lozupone, Alessandro Bria, Francesco Fontanella, Frederick J. A. Meijer, Claudio De Stefano

TL;DR

This work addresses the need for accurate and explainable Alzheimer's disease (AD) diagnosis from MRI. It introduces AXIAL, a lightweight, attention-based framework that uses 2D CNNs to process slices of 3D MRIs along the axial, coronal, and sagittal planes, learns the importance of each slice, and synthesizes a voxel-level 3D attention map. A double transfer learning strategy carries knowledge from the AD vs. cognitively normal (CN) task over to the detection of early-stage changes in mild cognitive impairment (MCI). This yields Matthews correlation coefficients (MCC) of 0.712 (AD vs. CN) and 0.443 (stable vs. progressive MCI, sMCI vs. pMCI) on a standardized ADNI1 subset. The attention maps localize brain regions clinically associated with AD (hippocampus, amygdala, parahippocampal gyrus, inferior lateral ventricles), and the same regions appear consistently across cross-validation folds. The open-source code supports reproducibility and future benchmarking.
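The slice-weighting idea above can be illustrated with a minimal soft-attention pooling sketch in NumPy. It assumes a tanh-based scoring function in the style of attention-based multiple-instance pooling, with hypothetical parameter names `V` and `w` and random vectors standing in for the 2D CNN slice embeddings; AXIAL's actual attention layer may differ in form and dimensions.

```python
import numpy as np

def attention_pool(slice_feats: np.ndarray, V: np.ndarray, w: np.ndarray):
    """Soft attention over the slices of one volume (illustrative sketch).

    slice_feats: (n_slices, d) features, one row per 2D slice.
    V: (d, h) and w: (h,) are hypothetical learnable parameters.
    Returns the volume-level embedding (d,) and the slice weights (n_slices,).
    """
    scores = np.tanh(slice_feats @ V) @ w   # one scalar score per slice
    e = np.exp(scores - scores.max())       # subtract max for numerical stability
    weights = e / e.sum()                   # softmax: positive, sums to 1
    embedding = weights @ slice_feats       # attention-weighted sum of slice features
    return embedding, weights

# Toy usage: 80 slices with 512-dim features and random parameters.
rng = np.random.default_rng(0)
feats = rng.standard_normal((80, 512))
V = rng.standard_normal((512, 128)) * 0.01
w = rng.standard_normal(128)
emb, slice_weights = attention_pool(feats, V, w)
print(emb.shape, slice_weights.shape, round(float(slice_weights.sum()), 6))
```

Because the weights are a softmax over slices, the same vector that builds the volume embedding can be read back as per-slice importance, which is what makes this kind of pooling usable for explanation.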

Abstract

This study presents an innovative method for Alzheimer's disease (AD) diagnosis from 3D MRI, designed to enhance the explainability of model decisions. Our approach adopts a soft attention mechanism that enables 2D CNNs to extract volumetric representations. At the same time, the importance of each slice in decision-making is learned, which allows a voxel-level attention map to be generated, producing an explainable MRI. To test our method and ensure the reproducibility of our results, we chose a standardized collection of MRI data from the Alzheimer's Disease Neuroimaging Initiative (ADNI). On this dataset, our method significantly outperforms state-of-the-art methods (i) in distinguishing AD from cognitively normal (CN) subjects, with an accuracy of 0.856 and a Matthews correlation coefficient (MCC) of 0.712, improvements of 2.4% and 5.3%, respectively, over the second-best method, and (ii) in the prognostic task of discerning stable from progressive mild cognitive impairment (MCI), with an accuracy of 0.725 and an MCC of 0.443, improvements of 10.2% and 20.5%, respectively, over the second-best method. We achieved this prognostic result by adopting a double transfer learning strategy, which enhanced sensitivity to morphological changes and facilitated early-stage AD detection. With voxel-level precision, our method identified the brain regions the model attends to most: the hippocampus, the amygdala, the parahippocampal gyrus, and the inferior lateral ventricles. All these areas are clinically associated with AD development. Furthermore, our approach consistently found the same AD-related areas across different cross-validation folds, indicating its robustness and its precision in highlighting areas that align closely with known pathological markers of the disease.
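For reference, the Matthews correlation coefficient reported above is computed from the four confusion-matrix counts; unlike accuracy, it uses every cell of the matrix and stays informative when the classes are imbalanced. The plain-Python sketch below implements the standard formula; the function names and the example counts are hypothetical and not taken from the paper.

```python
import math

def accuracy(tp: int, tn: int, fp: int, fn: int) -> float:
    """Fraction of correct predictions."""
    return (tp + tn) / (tp + tn + fp + fn)

def mcc(tp: int, tn: int, fp: int, fn: int) -> float:
    """Matthews correlation coefficient, in [-1, 1]."""
    num = tp * tn - fp * fn
    den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # By convention, return 0 when any row or column of the matrix is empty.
    return num / den if den else 0.0

# Hypothetical counts for a binary task (positive class = AD or pMCI).
tp, tn, fp, fn = 40, 45, 5, 10
print(f"accuracy={accuracy(tp, tn, fp, fn):.3f}  MCC={mcc(tp, tn, fp, fn):.3f}")
```

An MCC of 1 means perfect agreement, 0 means no better than chance, and -1 means total disagreement, which is why it is a stricter summary than accuracy for these tasks.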

Paper Structure (39 sections, 18 equations, 9 figures, 13 tables)

Figures (9)

  • Figure 1: Schematic representation of the proposed diagnostic framework. The Diagnosis and XAI framework processes a 3D sMRI brain image to generate two key outputs: a diagnosis of AD or CN from three networks, one for each slicing axis, and a corresponding 3D attention map that can be overlaid on the input image to visually highlight the brain regions the networks focus on to derive their diagnosis.
  • Figure 2: MRI data pre-processing pipeline: (i) original sMRI, (ii) N4 bias field correction, (iii) MNI152 space normalization, and (iv) skull stripping.
  • Figure 3: The proposed Diagnosis and XAI network: (i) Feature Extraction module, (ii) Attention XAI Fusion module, and (iii) Diagnosis module.
  • Figure 4: Representation of the proposed attention-based XAI approach. Three distinct networks are trained, one for each slicing plane, and the slice attention weights for each plane are combined to produce a 3D attention map.
  • Figure 5: Average attention weight distributions generated by our model for each fold and each plane.
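Figure 4 describes combining the per-plane slice attention weights into one 3D attention map. One simple rule, assumed here for illustration rather than taken from the paper, is an outer product: voxel (x, y, z) receives the product of the weights of the sagittal, coronal, and axial slices that pass through it, and the result is min-max rescaled so it can be overlaid on the MRI as a heat map. The function name `fuse_plane_attention` and the (sagittal, coronal, axial) axis order are hypothetical.

```python
import numpy as np

def fuse_plane_attention(a_sag, a_cor, a_ax) -> np.ndarray:
    """Combine per-slice attention vectors from three planes into a 3D map.

    Assumed rule: map[x, y, z] = a_sag[x] * a_cor[y] * a_ax[z] (outer product),
    then min-max rescaled to [0, 1] for overlay.
    """
    a_sag = np.asarray(a_sag, dtype=float)
    a_cor = np.asarray(a_cor, dtype=float)
    a_ax = np.asarray(a_ax, dtype=float)
    volume = np.einsum("i,j,k->ijk", a_sag, a_cor, a_ax)
    vmin, vmax = volume.min(), volume.max()
    if vmax == vmin:  # flat map: nothing to highlight
        return np.zeros_like(volume)
    return (volume - vmin) / (vmax - vmin)

def softmax(v: np.ndarray) -> np.ndarray:
    e = np.exp(v - v.max())
    return e / e.sum()

# Toy usage with random slice weights for a 4 x 5 x 6 volume.
rng = np.random.default_rng(1)
attn_map = fuse_plane_attention(softmax(rng.standard_normal(4)),
                                softmax(rng.standard_normal(5)),
                                softmax(rng.standard_normal(6)))
print(attn_map.shape, float(attn_map.min()), float(attn_map.max()))
```

A multiplicative rule highlights only voxels that all three planes weight highly; averaging the broadcast weights instead would give a smoother, more permissive map.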