Attri-Net: A Globally and Locally Inherently Interpretable Model for Multi-Label Classification Using Class-Specific Counterfactuals

Susu Sun, Stefano Woerner, Andreas Maier, Lisa M. Koch, Christian F. Baumgartner

TL;DR

Attri-Net addresses the interpretability gap in multi-label medical image classification by learning class-specific counterfactual attribution maps that directly drive linear classifiers. The model yields faithful local explanations via weighted attribution maps and global explanations through learned class centers and classifier weights, with an optional guidance mechanism to align explanations with human knowledge. Empirical results on CheXpert, ChestX-ray8, and VinDr-CXR show competitive classification performance and superior explainability, including the ability to detect and mitigate shortcut learning through global explanations and guidance. The approach holds promise for safer clinical deployment by combining faithful, interpretable reasoning with robust performance and a practical pathway for incorporating expert annotations.

Abstract

Interpretability is crucial for machine learning algorithms in high-stakes medical applications. However, high-performing neural networks typically cannot explain their predictions. Post-hoc explanation methods provide a way to understand neural networks but have been shown to suffer from conceptual problems. Moreover, current research largely focuses on providing local explanations for individual samples rather than global explanations for the model itself. In this paper, we propose Attri-Net, an inherently interpretable model for multi-label classification that provides both local and global explanations. Attri-Net first counterfactually generates class-specific attribution maps to highlight disease evidence, then performs classification with logistic regression classifiers based solely on the attribution maps. Local explanations for each prediction can be obtained by interpreting the attribution maps weighted by the classifiers' weights. A global explanation of the whole model can be obtained by jointly considering the learned average representations of the attribution maps for each class (called the class centers) and the weights of the linear classifiers. To ensure the model is "right for the right reason", we further introduce a mechanism to guide the model's explanations to align with human knowledge. Our comprehensive evaluations show that Attri-Net can generate high-quality explanations consistent with clinical knowledge without sacrificing classification performance.
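The classification and explanation steps described in the abstract can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the attribution map for a class is taken as given (the counterfactual generator is omitted), downsampling is assumed to be average pooling by an integer factor, the weighted attribution map is formed at the pooled resolution, and the class centers are computed as plain means of pooled maps rather than learned during training. The function names `downsample`, `classify`, `local_explanation`, and `class_centers` are hypothetical.

```python
import numpy as np


def downsample(attr_map: np.ndarray, factor: int) -> np.ndarray:
    """Average-pool a 2D attribution map by an integer factor (assumed pooling choice)."""
    h, w = attr_map.shape
    if h % factor or w % factor:
        raise ValueError("map size must be divisible by the pooling factor")
    return attr_map.reshape(h // factor, factor, w // factor, factor).mean(axis=(1, 3))


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + np.exp(-z))


def classify(attr_map, weight, bias, factor):
    """Per-class logistic regression on the pooled attribution map.

    Returns the predicted probability and the raw logit.
    """
    pooled = downsample(attr_map, factor)
    logit = float(np.sum(weight * pooled) + bias)
    return sigmoid(logit), logit


def local_explanation(attr_map, weight, factor):
    """Weighted attribution map: element-wise product of pooled map and weights."""
    return downsample(attr_map, factor) * weight


def class_centers(attr_maps, labels, factor):
    """Mean pooled attribution map of positive and negative samples.

    Hypothetical stand-in for the learned class centers described in the paper.
    """
    pooled = np.stack([downsample(m, factor) for m in attr_maps])
    labels = np.asarray(labels)
    return pooled[labels == 1].mean(axis=0), pooled[labels == 0].mean(axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    weight = rng.normal(size=(4, 4))  # hypothetical 4x4 classifier weights
    bias = -0.5
    maps = [rng.normal(size=(16, 16)) for _ in range(6)]  # stand-in attribution maps
    labels = [1, 0, 1, 1, 0, 0]

    prob, logit = classify(maps[0], weight, bias, factor=4)
    local = local_explanation(maps[0], weight, factor=4)
    print(f"p = {prob:.3f}, logit = {logit:.3f}, sum(local) + bias = {local.sum() + bias:.3f}")

    pos_center, neg_center = class_centers(maps, labels, factor=4)
    # Global view: class centers weighted by the same classifier weights.
    print("positive center logit:", float((pos_center * weight).sum() + bias))
    print("negative center logit:", float((neg_center * weight).sum() + bias))
```

In this sketch the logit equals the sum of the weighted attribution map plus the bias, so the local explanation accounts exactly for the prediction. That additivity is what a linear classifier on attribution maps buys over post-hoc saliency.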

Paper Structure (45 sections, 9 equations, 17 figures, 10 tables)

Figures (17)

  • Figure 1: Overview of the Attri-Net framework. Given an input image $\textbf{x}$ and a diagnostic task $\mathbf{t_c}$, the visual feature attribution generator (a) produces counterfactual attribution maps $M_c(\mathbf{x})$ that highlight specific disease effects. Logistic regression classifiers in (b) produce the final prediction for each class based on downsampled versions of these attribution maps.
  • Figure 2: Examples of counterfactual images. The top row shows an input image $\mathbf{x}$ that is positive for cardiomegaly, the attribution map $M_c(\mathbf{x})$ (sign flipped for better visualization), and the counterfactual image $\hat{\mathbf{x}}$. The bottom row shows the same for a negative sample. As expected, the residual changes $M_c(\mathbf{x})$ are large for the positive sample and small for the negative sample.
  • Figure 3: Generation of pseudo guidance masks. An example for the disease cardiomegaly from the ChestX-ray8 dataset is shown. (a) A chest X-ray image with its ground truth bounding box annotation for cardiomegaly. (b) The same image with multiple cardiomegaly bounding box annotations from other cases in the ChestX-ray8 dataset. (c) The same image with the binary pseudo-mask generated from multiple cardiomegaly bounding box annotations.
  • Figure 4: Local explanation of Attri-Net for an example from the CheXpert dataset with cardiomegaly. The weighted attribution map serves as the local explanation for a specific prediction. It is defined as the element-wise product of the attribution map produced by the class attribution generator and the weight matrix of the corresponding logistic regression classifier.
  • Figure 5: Global explanation of Attri-Net for a model trained on the CheXpert dataset. The positive and negative class centers of attribution maps and the corresponding classifiers’ weight matrix together provide a global explanation for Attri-Net.
  • ...and 12 more figures
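The pseudo-mask construction in Figure 3, and one way such a mask could steer the explanations, can be sketched as follows. Several assumptions here do not come from the text: boxes are given as hypothetical `(x, y, width, height)` pixel tuples already registered to a common image grid, the pseudo-mask is the union of all boxes, and the guidance term is an illustrative L1 penalty on attribution that falls outside the mask. The paper's actual guidance loss may differ; `pseudo_mask` and `guidance_penalty` are hypothetical names.

```python
import numpy as np


def pseudo_mask(boxes, height, width):
    """Union of (x, y, w, h) boxes as a binary mask (hypothetical box format)."""
    mask = np.zeros((height, width), dtype=bool)
    for x, y, w, h in boxes:
        # Clip each box to the image grid before filling it in.
        x0, y0 = max(int(x), 0), max(int(y), 0)
        x1, y1 = min(int(x + w), width), min(int(y + h), height)
        mask[y0:y1, x0:x1] = True
    return mask


def guidance_penalty(attr_map, mask):
    """Mean absolute attribution outside the mask (illustrative L1 penalty)."""
    outside = np.abs(attr_map) * ~mask
    return float(outside.sum() / attr_map.size)


if __name__ == "__main__":
    # Boxes pooled from several (hypothetical) cardiomegaly cases on a 64x64 grid.
    boxes = [(20, 30, 24, 18), (18, 28, 26, 20), (22, 32, 22, 16)]
    mask = pseudo_mask(boxes, 64, 64)
    rng = np.random.default_rng(0)
    attr = rng.normal(size=(64, 64))
    print("mask coverage:", mask.mean())
    print("penalty:", guidance_penalty(attr, mask))
    print("penalty if attribution stays inside mask:", guidance_penalty(attr * mask, mask))
```

Because the mask is built from boxes of other patients, it encodes where evidence for a class usually appears rather than where it appears in a particular image, so in this sketch it acts as a coarse spatial prior rather than a per-image label.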