Learning Causal Alignment for Reliable Disease Diagnosis
Mingzhou Liu, Ching-Wen Lee, Xinwei Sun, Yu Qiao, Yizhou Wang
TL;DR
The paper addresses how to align machine learning–based medical diagnosis with radiologists' causal decision pathways, so that models do not rely on spurious correlations. It introduces a causal alignment framework that uses counterfactual generation to identify decision-relevant regions, together with a causal alignment loss that enforces agreement with radiologist explanations. Because the counterfactual solver is implicit, the loss is optimized with the implicit function theorem and the conjugate gradient method. A hierarchical extension leverages attribute annotations via CCCE-based causal attribution, enabling alignment at both the attribute and image-region levels. Experiments on lung nodule and breast mass datasets show improved CAM precision and malignancy accuracy, indicating faithful, transferable alignment to expert reasoning. The work advances trustworthy medical AI by grounding model decisions in clinician-driven causality rather than superficial correlations.
Abstract
Aligning the decision-making process of machine learning algorithms with that of experienced radiologists is crucial for reliable diagnosis. Existing methods attempt to align a model's diagnostic behavior with that of radiologists as reflected in the training data, but this alignment is primarily associational rather than causal, so the model can learn spurious correlations that may not transfer well. In this paper, we propose a causality-based framework for aligning the model's decision process with that of experts. Specifically, we first employ counterfactual generation to identify the causal chain of model decisions. To align this causal chain with that of experts, we propose a causal alignment loss that forces the model to focus on the causal factors underlying each decision step in the chain. Because this loss involves the counterfactual generator as an implicit function of the model's parameters, we optimize it using the implicit function theorem together with the conjugate gradient method for efficient gradient estimation. We demonstrate the effectiveness of our method on two medical diagnosis applications, showcasing faithful alignment to radiologists.
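
To make the optimization idea concrete, the following is a minimal sketch of implicit differentiation with conjugate gradient on a toy problem. It is not the authors' implementation: the quadratic inner problem is an assumed stand-in for the counterfactual generator, and every name in it (A, theta, target, inner_solution, outer_loss, hypergradient) is hypothetical. The inner solution x*(theta) satisfies F(x, theta) = A x - theta = 0, so the implicit function theorem gives dL/dtheta = -(dF/dtheta)^T v, where v solves (dF/dx)^T v = dL/dx. Conjugate gradient solves that linear system using only matrix-vector products.

```python
import numpy as np

def conjugate_gradient(matvec, b, tol=1e-10, max_iter=100):
    """Solve A v = b for a symmetric positive-definite A, given only v -> A v."""
    v = np.zeros_like(b)
    r = b - matvec(v)          # residual
    p = r.copy()               # search direction
    rs = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        v = v + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return v

# Toy setup (all values are illustrative assumptions).
rng = np.random.default_rng(0)
n = 5
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)    # symmetric positive-definite Hessian of the inner problem
theta = rng.standard_normal(n)
target = rng.standard_normal(n)

def inner_solution(theta):
    # x*(theta) = argmin_x 0.5 x^T A x - theta^T x, i.e. the root of A x - theta = 0.
    return np.linalg.solve(A, theta)

def outer_loss(theta):
    # Outer objective evaluated at the implicit inner solution.
    x_star = inner_solution(theta)
    return 0.5 * np.sum((x_star - target) ** 2)

def hypergradient(theta):
    # Implicit function theorem: solve (dF/dx)^T v = dL/dx with conjugate gradient,
    # then dL/dtheta = -(dF/dtheta)^T v. Here dF/dx = A and dF/dtheta = -I.
    x_star = inner_solution(theta)
    grad_x = x_star - target
    v = conjugate_gradient(lambda u: A @ u, grad_x)
    return v

def finite_difference_grad(f, theta, eps=1e-6):
    # Central differences, used only to sanity-check the implicit gradient.
    g = np.zeros_like(theta)
    for i in range(theta.size):
        e = np.zeros_like(theta)
        e[i] = eps
        g[i] = (f(theta + e) - f(theta - e)) / (2 * eps)
    return g

if __name__ == "__main__":
    print(hypergradient(theta))
    print(finite_difference_grad(outer_loss, theta))
```

In this toy case the Hessian is available as an explicit matrix, so a direct solve would also work. Conjugate gradient becomes useful when the Hessian is only accessible through Hessian-vector products, as is typical for neural networks.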
