Cross-modal Causal Intervention for Alzheimer's Disease Prediction

Yutao Jin, Haowen Xiao, Junyong Zhai, Yuxiao Li, Jielei Chu, Fengmao Lv, Yuxiao Li

TL;DR

MediAD addresses Alzheimer's disease diagnosis under confounding from unobserved variables by grounding multi-modal predictions in a structural causal model. It introduces a cross-modal Causal Fusion module that generates a mediator from fused visual and textual features, regularizes this mediator with a consistency loss, and applies a Front-Door Adjustment to mitigate confounding effects. Textual inputs are enriched with LLM-generated clinical summaries, yielding a richer multi-modal representation alongside MRI-derived features. Experiments on NACC and ADNI show that MediAD achieves state-of-the-art or competitive accuracy on three-way CN/MCI/AD classification and binary CN/AD classification, supporting the value of combining causal intervention with multi-modal learning for neurological diagnosis.
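The front-door adjustment can be illustrated on a toy problem. The sketch below builds a hypothetical binary SCM with a hidden confounder S (S → X, S → Y) and a mediator M on the only causal path X → M → Y. It then compares three quantities: the naive conditional P(Y|X), the front-door estimate P(Y|do(X=x)) = Σ_m P(m|x) Σ_x' P(Y|m,x') P(x'), which uses only observed variables, and the ground truth computed with access to S. All probabilities are invented for illustration. In MediAD the roles of X and M are played by learned multi-modal features and the CF-generated mediator, so this discrete computation is not the paper's implementation.

```python
import itertools

# Hypothetical toy SCM (all variables binary):
#   S -> X, S -> Y   (S is the unobserved confounder)
#   X -> M -> Y      (M is the mediator)
P_S = {0: 0.6, 1: 0.4}                        # P(S=s)
P_X1_given_S = {0: 0.2, 1: 0.8}               # P(X=1 | S=s)
P_M1_given_X = {0: 0.1, 1: 0.7}               # P(M=1 | X=x)
P_Y1_given_MS = {(0, 0): 0.1, (0, 1): 0.5,    # P(Y=1 | M=m, S=s)
                 (1, 0): 0.4, (1, 1): 0.9}

def bern(p1, v):
    """Probability that a Bernoulli(p1) variable takes value v."""
    return p1 if v == 1 else 1.0 - p1

# Full joint P(s, x, m, y) implied by the SCM.
joint = {}
for s, x, m, y in itertools.product((0, 1), repeat=4):
    joint[(s, x, m, y)] = (P_S[s] * bern(P_X1_given_S[s], x)
                           * bern(P_M1_given_X[x], m)
                           * bern(P_Y1_given_MS[(m, s)], y))

def marg(**fixed):
    """Probability of a partial assignment (keywords among s, x, m, y); the rest is summed out."""
    names = ("s", "x", "m", "y")
    return sum(p for key, p in joint.items()
               if all(key[names.index(k)] == v for k, v in fixed.items()))

def naive(x):
    """Observational P(Y=1 | X=x); biased by the back-door path X <- S -> Y."""
    return marg(x=x, y=1) / marg(x=x)

def front_door(x):
    """Front-door estimate of P(Y=1 | do(X=x)) using only the observed x, m, y."""
    total = 0.0
    for m in (0, 1):
        p_m_given_x = marg(x=x, m=m) / marg(x=x)
        inner = sum(marg(x=xp, m=m, y=1) / marg(x=xp, m=m) * marg(x=xp)
                    for xp in (0, 1))
        total += p_m_given_x * inner
    return total

def ground_truth(x):
    """True P(Y=1 | do(X=x)), computed with access to the hidden S."""
    return sum(P_S[s] * bern(P_M1_given_X[x], m) * P_Y1_given_MS[(m, s)]
               for s in (0, 1) for m in (0, 1))

for x in (0, 1):
    print(f"x={x}: naive={naive(x):.4f}  front-door={front_door(x):.4f}  "
          f"truth={ground_truth(x):.4f}")
```

The front-door estimate matches the ground truth even though S is never observed, while the naive conditional does not. This works because M is affected by X alone, so P(M|X) already equals P(M|do(X)), the condition stated for the mediator in the Figure 4 caption.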

Abstract

Mild Cognitive Impairment (MCI) is a prodromal stage of Alzheimer's Disease (AD), in which early identification and intervention can effectively slow progression to dementia. However, diagnosing AD remains a significant challenge in neurology because of confounders, which arise mainly from selection bias in multi-modal data and from complex relationships between variables. To address these issues, we propose a novel causality-inspired visual-language framework for diagnostic assistance, named Cross-modal Causal Intervention with Mediator for Alzheimer's Disease Diagnosis (MediAD). MediAD employs Large Language Models (LLMs) to summarize clinical data under strict templates, thereby enriching the textual inputs. The model uses Magnetic Resonance Imaging (MRI), clinical data, and the LLM-enriched textual data to classify participants as Cognitively Normal (CN), MCI, or AD. Because of confounders such as cerebral vascular lesions and age-related biomarkers, non-causal models are likely to capture spurious input-output correlations and produce less reliable results. Our framework implicitly mitigates the effects of both observable and unobservable confounders through a unified causal intervention method. Experimental results show that our method performs strongly in distinguishing CN/MCI/AD cases, outperforming other methods on most evaluation metrics. The study demonstrates the potential of integrating causal reasoning with multi-modal learning for neurological disease diagnosis.

Paper Structure

This paper contains 25 sections, 15 equations, 8 figures, 6 tables, and 1 algorithm.

Figures (8)

  • Figure 1: Doctors use multi-modal medical data to achieve more accurate diagnoses through mediation analysis and confounding factor screening. Confounding factor screening is often used in scenarios where confounders are observable, whereas mediation analysis is frequently employed for accurate clinical diagnosis in settings where confounders are unobservable.
  • Figure 2: The structural causal model (SCM). (a) illustrates the causal pathway from input $X$ to output $Y$, where both $X$ and $Y$ are influenced by the confounding factor $S$. (b) demonstrates how front-door adjustment decomposes the causal pathway from $X$ to $Y$ to estimate the causal effect between them when confounders are unobservable. (c) shows the SCM constructed for our proposed framework. (d) demonstrates how our framework applies front-door adjustment to address unobservable visual and textual confounders by selecting appropriate mediators for the multi-modal features.
  • Figure 3: The structure of the cross-modal Causal Fusion (CF) module. The visual and textual features undergo cross-attention and are concatenated. The fused features are then passed through CaaM, a causal attention module, to further mitigate the effects of confounders, and the refined representation is output as the mediator.
  • Figure 4: In the SCM, the expression $do(F=f)$ denotes an intervention on the multi-modal feature $F$, which graphically removes all incoming arrows to $F$. One condition for the front-door adjustment is that no unblocked backdoor path exists between $F$ and the mediator $M$. This condition is violated if an unobserved confounder $S$ creates such a backdoor path (e.g., $F \leftarrow S \rightarrow M$); the observational probability $P(M|F)$ then differs from the interventional probability $P(M|do(F))$. Conversely, $P(M|do(F)) = P(M|F)$ implies that no such unblocked backdoor path exists.
  • Figure 5: An overview of our proposed MediAD. MRI/fMRI scans serve as the visual input, while the textual input consists of summaries generated by Large Language Models (LLMs) and structured clinical data. Visual features are extracted using a 3D-CNN. Within the MediAD framework, these visual and textual features are processed by the cross-modal Causal Fusion (CF) module to generate mediators. The mediators are optimized with a consistency regularization loss ($\mathcal{L}_{cl}$). The multi-modal features and the optimized mediators are then fed into the Front-Door Adjustment (FDA) module to mitigate the effect of confounders. Finally, the refined features from this module are passed to a classifier for final diagnosis.
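The cross-attention-and-concatenate step described for the CF module in Figure 3 can be sketched with NumPy. This is a minimal illustration under stated assumptions, not the authors' code. It uses single-head scaled dot-product attention applied in both directions (the caption does not say which modality queries which), mean-pools each direction, and concatenates the two results. The names `cross_attention`, `fuse`, and `params`, the width `d`, the token counts, and the random untrained weights are all hypothetical, and the subsequent CaaM stage is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries, context, Wq, Wk, Wv):
    """Single-head scaled dot-product attention: `queries` attend over `context`."""
    Q, K, V = queries @ Wq, context @ Wk, context @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    return softmax(scores, axis=-1) @ V

def fuse(visual, textual, params):
    """Hypothetical CF-style fusion: bidirectional cross-attention, mean-pool, concatenate."""
    v2t = cross_attention(visual, textual, *params["v2t"])   # visual tokens query the text
    t2v = cross_attention(textual, visual, *params["t2v"])   # text tokens query the image
    return np.concatenate([v2t.mean(axis=0), t2v.mean(axis=0)])

d = 32                                    # assumed shared feature width
params = {name: [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3)]
          for name in ("v2t", "t2v")}     # (Wq, Wk, Wv) per direction, untrained
visual = rng.standard_normal((8, d))      # e.g. 8 tokens from a 3D-CNN feature map
textual = rng.standard_normal((20, d))    # e.g. 20 tokens from an encoded clinical summary
fused = fuse(visual, textual, params)
print(fused.shape)                        # (64,)
```

A one-directional variant would simply drop `t2v`. In a trained model the projection matrices would be learned, and the fused vector would feed the causal attention stage before serving as the mediator.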