Knowledge Distillation of Convolutional Neural Networks through Feature Map Transformation using Decision Trees
Maddimsetti Srinivas, Debdoot Sheet
TL;DR
The paper addresses the interpretability of CNNs in medical imaging by distilling their decisions into an interpretable surrogate. It extracts the final CNN feature maps, flattens them into a $1024$-dimensional vector, maps that vector to a $4$-dimensional feature vector with a fully connected layer, and trains a decision tree on these features under depth and node constraints. The key contribution is showing that a decision tree on this compact representation can reach accuracy competitive with the CNN on MedMNIST datasets (dermaMNIST, octMNIST, pneumoniaMNIST), supported by an analysis of feature distributions and inter-feature correlations. The approach offers a practical path to transparent, low-complexity reasoning about CNN decisions and can be extended to additional layers to balance accuracy and explainability.
Abstract
Interpreting the reasoning of Deep Neural Networks (DNNs) remains challenging because of their perceived black-box nature, and this lack of transparency restricts their deployment in several real-world tasks. We propose a distillation approach that extracts features from the final layer of a convolutional neural network (CNN) to gain insight into its reasoning. The feature maps in the final layer of the CNN are transformed into a one-dimensional feature vector using a fully connected layer. The extracted features are then used to train a decision tree that achieves the best accuracy under constraints on depth and number of nodes. We demonstrate the approach on the dermaMNIST, octMNIST, and pneumoniaMNIST medical image datasets from MedMNIST. We observe that the decision tree performs as well as the CNN while keeping complexity to a minimum. The results encourage the use of decision trees to interpret decisions made by CNNs.
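The pipeline described above (final feature maps, flattening, a fully connected layer, then a depth-constrained decision tree) can be illustrated with a minimal, dependency-free sketch. Everything specific in it is an assumption made for illustration rather than the authors' implementation: the 64 x 4 x 4 feature-map shape (chosen only so that the flattened vector has 1024 entries), the synthetic two-class data, the fixed block-averaging weights that stand in for the trained fully connected layer, and the helper names. Only the depth constraint is enforced; the node count is reported rather than limited. In practice a library tree, such as scikit-learn's DecisionTreeClassifier with its max_depth and max_leaf_nodes parameters, would likely be used instead.

```python
import random

def flatten(feature_maps):
    """Flatten a C x H x W nested list into one 1-D list (channel-major)."""
    return [v for channel in feature_maps for row in channel for v in row]

def fully_connected(x, weights, bias):
    """Linear layer: one output per row of `weights`."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]

def gini(labels):
    """Gini impurity of a non-empty list of class labels."""
    n = len(labels)
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def best_split(X, y):
    """Return the (feature, threshold) that most reduces weighted Gini, or None."""
    best, best_score = None, gini(y)
    for f in range(len(X[0])):
        values = sorted(set(row[f] for row in X))
        for t in values[:-1]:  # excluding the max keeps both sides non-empty
            left = [lab for row, lab in zip(X, y) if row[f] <= t]
            right = [lab for row, lab in zip(X, y) if row[f] > t]
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            if score < best_score - 1e-12:
                best, best_score = (f, t), score
    return best

def build_tree(X, y, max_depth):
    """Grow a tree until the depth limit is reached or no split helps."""
    split = best_split(X, y) if max_depth > 0 else None
    if split is None:
        return max(set(y), key=y.count)  # leaf: majority class
    f, t = split
    left = [(r, lab) for r, lab in zip(X, y) if r[f] <= t]
    right = [(r, lab) for r, lab in zip(X, y) if r[f] > t]
    return (f, t,
            build_tree([r for r, _ in left], [lab for _, lab in left], max_depth - 1),
            build_tree([r for r, _ in right], [lab for _, lab in right], max_depth - 1))

def predict(tree, x):
    """Follow threshold tests from the root down to a leaf label."""
    while isinstance(tree, tuple):
        f, t, left, right = tree
        tree = left if x[f] <= t else right
    return tree

def count_nodes(tree):
    """Total number of internal nodes and leaves."""
    if not isinstance(tree, tuple):
        return 1
    return 1 + count_nodes(tree[2]) + count_nodes(tree[3])

random.seed(0)
C, H, W, OUT = 64, 4, 4, 4  # assumed shape: 64 * 4 * 4 = 1024 flattened values
block = C * H * W // OUT
# Stand-in for the trained layer: output k averages the k-th quarter of the vector.
weights = [[1.0 / block if i // block == k else 0.0 for i in range(C * H * W)]
           for k in range(OUT)]
bias = [0.0] * OUT

def sample(label):
    """Synthetic feature maps: class 1 shifts the first quarter of the channels."""
    maps = [[[random.gauss(0.5 if label == 1 and ch < C // OUT else 0.0, 1.0)
              for _ in range(W)] for _ in range(H)] for ch in range(C)]
    return fully_connected(flatten(maps), weights, bias)

labels = [i % 2 for i in range(200)]
features = [sample(lab) for lab in labels]  # 4-dimensional vectors
train_X, train_y = features[:150], labels[:150]
test_X, test_y = features[150:], labels[150:]

tree = build_tree(train_X, train_y, max_depth=3)
accuracy = sum(predict(tree, x) == t for x, t in zip(test_X, test_y)) / len(test_y)
print(f"nodes={count_nodes(tree)} test accuracy={accuracy:.2f}")
```

Because every internal node is a single threshold test on one of the four features, each prediction can be read off as a short chain of comparisons, which is the property that makes the tree usable as an interpretable surrogate.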
