A Knowledge Distillation-Based Approach to Enhance Transparency of Classifier Models

Yuchen Jiang, Xinyuan Zhao, Yihang Wu, Ahmad Chaddad

TL;DR

This work tackles the need for explainable AI in medical image analysis by introducing a KD-based framework that distills a DenseNet121 teacher into a shallow five-layer student. It uses KD-FMV, combining a hard loss $\ell_{HL}$ and a soft loss $\ell_{SL}$ with temperature $T$ to transfer feature representations via $\mathcal{L}_{distill}=\alpha \cdot \ell_{HL} + (1-\alpha) \ell_{SL}$, while leveraging average feature maps to visualize per-layer decision processes. Grad-CAM and SHAP are employed to validate interpretability, and the approach is evaluated on brain tumor, eye disease, and Alzheimer's datasets, with the student achieving near-teacher accuracy (and sometimes surpassing it) while reducing model depth. Additionally, the method reduces FLOPs and mean execution time, enabling faster and more efficient interpretability suitable for resource-limited clinical settings.
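The combined objective $\mathcal{L}_{distill}=\alpha \cdot \ell_{HL} + (1-\alpha) \ell_{SL}$ can be illustrated with a small, dependency-free sketch. The summary does not spell out the exact form of either term, so the choices below are assumptions: $\ell_{HL}$ is taken to be cross-entropy against the ground-truth label, and $\ell_{SL}$ the KL divergence between temperature-softened teacher and student distributions, scaled by $T^2$ as is conventional in knowledge distillation. The function names and default values of `alpha` and `temperature` are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def hard_loss(student_logits, label):
    """Assumed form of l_HL: cross-entropy of the student (T = 1) vs. the true label."""
    probs = softmax(student_logits)
    return -math.log(probs[label])


def soft_loss(student_logits, teacher_logits, temperature):
    """Assumed form of l_SL: T^2 * KL(teacher || student) on softened distributions."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(
        pt * (math.log(pt) - math.log(ps))
        for pt, ps in zip(p_teacher, p_student)
        if pt > 0.0
    )
    return temperature ** 2 * kl


def distill_loss(student_logits, teacher_logits, label, alpha=0.5, temperature=4.0):
    """L_distill = alpha * l_HL + (1 - alpha) * l_SL for a single sample."""
    hl = hard_loss(student_logits, label)
    sl = soft_loss(student_logits, teacher_logits, temperature)
    return alpha * hl + (1.0 - alpha) * sl


if __name__ == "__main__":
    # Hypothetical logits for a four-class problem.
    teacher = [4.0, 1.0, 0.5, 0.2]
    student = [2.5, 1.2, 0.3, 0.1]
    print(distill_loss(student, teacher, label=0))
```

With `alpha=1.0` the student is trained on labels alone; lowering `alpha` shifts weight toward matching the teacher's softened outputs, and a larger `temperature` flattens those outputs so that the relative scores of non-target classes carry more signal.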

Abstract

With the rapid development of artificial intelligence (AI), especially in the medical field, the need for explainability has grown. In medical image analysis, transparency and model interpretability help clinicians understand and trust the decision-making process of AI models. In this study, we propose a Knowledge Distillation (KD)-based approach that aims to enhance the transparency of AI models in medical image analysis. We first train a conventional CNN as the teacher model, then use KD to simplify the CNN architecture, reducing the number of network layers while retaining most of the features learned from the data set. We then use the feature maps of the student model to perform a layer-by-layer analysis that identifies key features and decision-making processes, which leads to intuitive visual explanations. We tested our method on three public medical data sets (brain tumor, eye disease, and Alzheimer's disease). The results show that, even with fewer layers, our model performs well on the test set and reduces the time required for the interpretability analysis.

Paper Structure

This paper contains 1 section, 7 equations, 4 figures, 4 tables, 1 algorithm.

Table of Contents

  1. Introduction

Figures (4)

  • Figure 1: Flowchart of knowledge distillation and explainable AI. 1) DenseNet121 is trained as the teacher, and its knowledge is distilled into a five-layer custom CNN by minimizing the distillation loss. 2) Feature maps are extracted layer by layer and averaged. 3) Color mapping of the averaged maps identifies the key parts of the image that influence the model's decision.
  • Figure 2: Confusion matrices of the student (left) and teacher (right) models on the brain tumor, eye disease, and Alzheimer's datasets, respectively. In the brain tumor dataset, 0, 1, 2, and 3 represent Glioma tumor, Meningioma, No tumor, and Pituitary, respectively. In the Alzheimer's dataset, 0, 1, 2, and 3 indicate Mild demented, Moderate demented, Non demented, and Very mild demented, respectively. In the eye disease dataset, 0, 1, 2, and 3 denote Cataract, Diabetic retinopathy, Glaucoma, and Normal, respectively.
  • Figure 3: Examples of original images, Grad-CAM and SHAP outputs (teacher), and the proposed method (student), randomly selected from the test set for each class in the three datasets for interpretability analysis. The Grad-CAM and SHAP analyses rely on the teacher models, whereas the proposed method uses the student models. The student model has five convolutional layers, and the average feature map is used to view the features of each layer. From left to right, each row shows the 12 disease classes from the three datasets: Pituitary, Meningioma, No tumor, Glioma, Cataract, Diabetic retinopathy, Glaucoma, Normal, Mild demented, Moderate demented, Non demented, and Very mild demented.
  • Figure 4: Examples of Grad-CAM and SHAP applied to the best student model. The first, second, and third rows show the original image, the Grad-CAM heatmap, and the SHAP output, respectively. From left to right, the columns show the 12 disease classes: Pituitary, Meningioma, No tumor, Glioma, Cataract, Diabetic retinopathy, Glaucoma, Normal, Mild demented, Moderate demented, Non demented, and Very mild demented.
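
The averaging step described in the Figure 1 caption can be sketched as follows. This is a minimal, dependency-free illustration and not the authors' implementation: it assumes that each layer's output is a channels × height × width array (nested lists here), that "average" means the mean across channels, and that the result is min-max scaled to [0, 1] before a colormap is applied. The function names `average_feature_map` and `normalize` are hypothetical.

```python
def average_feature_map(feature_map):
    """Average a C x H x W feature map (nested lists) over channels -> H x W."""
    channels = len(feature_map)
    height = len(feature_map[0])
    width = len(feature_map[0][0])
    return [
        [
            sum(feature_map[c][i][j] for c in range(channels)) / channels
            for j in range(width)
        ]
        for i in range(height)
    ]


def normalize(avg_map):
    """Min-max scale an H x W map to [0, 1] so a colormap can be applied."""
    flat = [v for row in avg_map for v in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        # A constant map carries no spatial information; return all zeros.
        return [[0.0 for _ in row] for row in avg_map]
    return [[(v - lo) / (hi - lo) for v in row] for row in avg_map]


if __name__ == "__main__":
    # Toy activation with 2 channels of size 2 x 2 (made-up values).
    fmap = [
        [[0.0, 2.0], [4.0, 6.0]],
        [[2.0, 4.0], [6.0, 8.0]],
    ]
    heat = normalize(average_feature_map(fmap))
    for row in heat:
        print(row)
```

Repeating this for each of the student's five convolutional layers gives one heatmap per layer; higher values mark regions with stronger average activation, which is what the color mapping then highlights.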