Multimodal Explainability via Latent Shift applied to COVID-19 stratification

Valerio Guarrasi, Lorenzo Tronchin, Domenico Albano, Eliodoro Faiella, Deborah Fazzini, Domiziana Santucci, Paolo Soda

TL;DR

This paper addresses the need for explainable AI on multimodal medical data by introducing an end-to-end architecture that jointly learns reconstructive and predictive tasks on tabular clinical data and chest X-ray images. On top of this model, the MXAI framework uses latent-shift counterfactuals to reveal how much each modality and each feature contributes to a prediction. The method is validated on the AIforCOVID dataset, where it performs competitively with baselines while providing intrinsic, local explanations that a reader study found to be aligned with radiologists' assessments. The work advances trustworthy AI for COVID-19 risk stratification and points to future work on concept-based representations for medical knowledge mining.
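To make the joint reconstructive and predictive objective concrete, here is a minimal NumPy sketch of one forward pass and the combined loss. It is not the authors' implementation: the layer sizes, the single linear layer per network, the flattened image (the paper uses a convolutional autoencoder), the MSE and binary cross-entropy choices, the unweighted sum $L_{RT} + L_{RI} + L_C$, and the way $\boldsymbol{h}$ is split between the decoders are all assumptions. Training is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 10 tabular features, a 16x16 image flattened to 256 values.
D_T, D_I, H_T, H_I = 10, 256, 4, 8


def dense(in_dim, out_dim):
    """Randomly initialised weights and zero bias for one linear layer."""
    return rng.normal(0.0, 0.1, size=(in_dim, out_dim)), np.zeros(out_dim)


def apply(layer, x):
    w, b = layer
    return x @ w + b


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# One linear layer per network (an assumption; the paper's networks are deeper).
E_AE, E_CAE = dense(D_T, H_T), dense(D_I, H_I)
D_AE, D_CAE = dense(H_T, D_T), dense(H_I, D_I)  # assumption: each decoder reads its own slice of h
C_MLP = dense(H_T + H_I, 1)                       # classifier reads the whole multimodal embedding


def forward(x_T, x_I):
    h_T = np.tanh(apply(E_AE, x_T))           # unimodal tabular embedding
    h_I = np.tanh(apply(E_CAE, x_I))          # unimodal imaging embedding
    h = np.concatenate([h_T, h_I], axis=-1)   # multimodal embedding
    x_T_hat = apply(D_AE, h[..., :H_T])       # tabular reconstruction
    x_I_hat = apply(D_CAE, h[..., H_T:])      # image reconstruction
    y = sigmoid(apply(C_MLP, h))[..., 0]      # predicted probability of the positive (severe) class
    return x_T_hat, x_I_hat, y


def joint_loss(x_T, x_I, labels, eps=1e-7):
    """Unweighted sum L_RT + L_RI + L_C (the weighting is an assumption)."""
    x_T_hat, x_I_hat, y = forward(x_T, x_I)
    L_RT = np.mean((x_T_hat - x_T) ** 2)      # tabular reconstruction loss (MSE)
    L_RI = np.mean((x_I_hat - x_I) ** 2)      # image reconstruction loss (MSE)
    y = np.clip(y, eps, 1.0 - eps)            # avoid log(0)
    L_C = -np.mean(labels * np.log(y) + (1 - labels) * np.log(1 - y))  # binary cross-entropy
    return L_RT + L_RI + L_C, (L_RT, L_RI, L_C)


# Usage on a random mini-batch of 5 hypothetical patients.
x_T = rng.normal(size=(5, D_T))
x_I = rng.random((5, D_I))
labels = np.array([0, 1, 0, 1, 1])
total, (L_RT, L_RI, L_C) = joint_loss(x_T, x_I, labels)
```

Minimizing `total` with any gradient-based optimizer would train the encoders on both reconstruction and classification at once, which is what gives the shared embedding its decodable structure.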

Abstract

We are witnessing a widespread adoption of artificial intelligence in healthcare. However, most advances of deep learning in this area consider only unimodal data, neglecting other modalities, even though their multimodal interpretation is necessary to support diagnosis, prognosis and treatment decisions. In this work we present a deep architecture that jointly learns modality reconstructions and sample classifications using tabular and imaging data. The explanation of each decision is computed by applying a latent shift that simulates a counterfactual prediction, revealing both the features of each modality that contribute most to the decision and a quantitative score indicating the importance of each modality. We validate our approach in the context of the COVID-19 pandemic using the AIforCOVID dataset, which contains multimodal data for the early identification of patients at risk of severe outcome. The results show that the proposed method provides meaningful explanations without degrading classification performance.
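The latent-shift counterfactual can be illustrated on a toy classifier. This sketch assumes a logistic classifier $f(\boldsymbol{h}) = \sigma(\boldsymbol{w}^\top \boldsymbol{h} + b)$ standing in for $C_{MLP}$ (the weights are made up), computes the gradient once at the original embedding, and scans a fixed grid of $\lambda$ values for the smallest shift that flips the predicted class. The grid and the sign convention used for negative predictions are assumptions, not details taken from the paper.

```python
import numpy as np

# Hypothetical logistic classifier standing in for C_MLP: f(h) = sigmoid(w . h + b).
w = np.array([1.5, -2.0, 0.5, 1.0])
b = -0.2


def f(h):
    return 1.0 / (1.0 + np.exp(-(h @ w + b)))


def grad_f(h):
    p = f(h)
    return p * (1.0 - p) * w  # analytic gradient of f with respect to h


def latent_shift(h, lambdas=np.linspace(0.0, 50.0, 2001)):
    """Smallest lambda > 0 on the grid whose shifted embedding flips the predicted class."""
    positive = f(h) >= 0.5
    g = grad_f(h)  # gradient taken once, at the original embedding
    # Step against the gradient to lower a positive prediction, along it to raise a negative one.
    direction = -g if positive else g
    for lam in lambdas[lambdas > 0]:
        h_lam = h + lam * direction
        if (f(h_lam) >= 0.5) != positive:
            return lam, h_lam
    return None, None  # no flip within the grid


h = np.array([1.0, 0.2, 0.3, 0.5])  # hypothetical multimodal embedding
lam, h_lam = latent_shift(h)
```

Here `h` is classified as positive, so the returned `h_lam` lies just across the decision boundary; in the full model, `h_lam` would also be decoded to obtain the counterfactual reconstructions.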

Paper Structure (33 sections, 20 equations, 3 figures, 8 tables)

Figures (3)

  • Figure 1: Schematic view of the multimodal deep architecture: for each instance, the input modalities $\boldsymbol{x}_T$ and $\boldsymbol{x}_I$ feed into their corresponding encoders $E_{AE}$ and $E_{CAE}$, obtaining the unimodal embeddings $\boldsymbol{h}_T$ and $\boldsymbol{h}_I$, respectively. These embeddings are then concatenated into the multimodal embedding $\boldsymbol{h}$, which subsequently feeds into the decoders $D_{AE}$, $D_{CAE}$, and the classifier $C_{MLP}$. The resulting outputs are the reconstructions $\hat{\boldsymbol{x}}_T$, $\hat{\boldsymbol{x}}_I$, and classification $\boldsymbol{y}$, respectively. The model is trained by simultaneously minimizing the reconstruction losses $L_{RT}$, $L_{RI}$, and the classification loss $L_C$.
  • Figure 2: Schematic view of the MXAI framework: once the model is trained, each instance's multimodal embedding $\boldsymbol{h}$ feeds into the decoders $D_{AE}$, $D_{CAE}$, and the classifier $C_{MLP}$, with colors indicating which portion of $\boldsymbol{h}$ is given to each network. The decoders and the classifier provide the original reconstructions $\hat{\boldsymbol{x}}_T$, $\hat{\boldsymbol{x}}_I$, and classification $\boldsymbol{y}$. The latent-shift method finds a $\lambda > 0$ such that feeding the shifted multimodal embedding $\boldsymbol{h}^{\lambda}$ to $C_{MLP}$ flips the classification to $\boldsymbol{y}^{\lambda}$. Feeding this new embedding to $D_{AE}$ and $D_{CAE}$ yields the new reconstructions $\hat{\boldsymbol{x}}^{\lambda}_T$ and $\hat{\boldsymbol{x}}^{\lambda}_I$. Comparing $\boldsymbol{h}$ with $\boldsymbol{h}^{\lambda}$ gives the multimodal explanation, while comparing $\hat{\boldsymbol{x}}_T$ with $\hat{\boldsymbol{x}}^{\lambda}_T$ and $\hat{\boldsymbol{x}}_I$ with $\hat{\boldsymbol{x}}^{\lambda}_I$ gives the unimodal explanations.
  • Figure 3: Four case studies: for each, we show the feature importance indicated by our proposal ($\hat{\boldsymbol{\Delta}}_T$ and $\hat{\boldsymbol{\Delta}}_I$) and the corresponding important features indicated by radiologists ($\hat{\boldsymbol{\Delta}}_T^{R_i}$ and $\hat{\boldsymbol{\Delta}}_I^{R_i}$) for the tabular and imaging modalities, respectively. The rows show patients with mild (first and third rows) and severe (second and fourth rows) outcomes, covering both success cases (first and second rows) and failure cases (third and fourth rows) of our model.
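The comparisons described in the Figure 2 and Figure 3 captions can be sketched as two small functions. Both are hypothetical. The per-feature explanation is taken as the absolute difference between the original and counterfactual reconstructions, scaled to [0, 1] (one plausible reading of $\hat{\boldsymbol{\Delta}}$). The modality score is taken as the share of the latent shift's L2 norm that falls in each modality's slice of $\boldsymbol{h}$. The paper's exact formulas may differ.

```python
import numpy as np


def feature_importance(x_hat, x_hat_shift):
    """Per-feature (or per-pixel) change between original and counterfactual reconstructions, scaled to [0, 1]."""
    delta = np.abs(x_hat_shift - x_hat)
    peak = delta.max()
    return delta / peak if peak > 0 else delta


def modality_importance(h, h_shift, n_tab):
    """Hypothetical score: share of the latent shift's L2 norm in the tabular and imaging slices of h."""
    shift = h_shift - h
    tab = np.linalg.norm(shift[:n_tab])   # first n_tab entries assumed to be h_T
    img = np.linalg.norm(shift[n_tab:])   # remaining entries assumed to be h_I
    total = tab + img
    if total == 0:
        return 0.5, 0.5                   # no shift: neither modality dominates
    return tab / total, img / total


# Toy values: 2 tabular and 3 imaging latent dimensions.
h = np.array([0.1, 0.4, -0.3, 0.2, 0.0])
h_shift = np.array([0.1, 0.1, -0.3, 0.2, 0.4])
s_T, s_I = modality_importance(h, h_shift, n_tab=2)  # roughly (3/7, 4/7)

delta_T = feature_importance(np.array([1.0, 2.0, 3.0]), np.array([1.5, 2.0, 2.0]))  # [0.5, 0.0, 1.0]
```

Because `feature_importance` works element-wise, the same function applies unchanged to a 2-D image reconstruction, producing a heat map of the kind compared against radiologists' annotations.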