Multi-Head Explainer: A General Framework to Improve Explainability in CNNs and Transformers
Bohang Sun, Pietro Liò
TL;DR
Multi-Head Explainer (MHEX) addresses explainability gaps in CNNs and Transformers with a modular framework that couples an Attention Gate, Deep Supervision, and an Equivalent Matrix to produce richer saliency maps without sacrificing accuracy. For CNNs, the framework derives saliency from Class Activation Maps (CAMs) that are refined by the Equivalent Matrix and integrated across layers. For Transformers, it collects token-level saliency across layers with controlled smoothing. Key contributions include a non-negativity constraint to reduce noise, tailored saliency metrics (AVG Drop, SAD, EAD), and a gradient-based collaboration analysis of how the MHEX components interact. Empirical results on ImageNet-1k, MedMNIST, and AG News show improved accuracy and more interpretable explanations for both CNNs and Transformers. The work also gives guidelines for integrating MHEX into residual blocks and attention stacks, and it discusses a possible extension to GNNs and segmentation tasks.
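AVG Drop is commonly defined, following Grad-CAM++, as the mean relative decrease in the model's confidence for the target class when the input is restricted to the region that the saliency map highlights. Lower values are better. The sketch below assumes that standard definition; the exact variant used with MHEX is not specified here, and SAD and EAD are not defined in this section, so they are not sketched. The function name `average_drop` is hypothetical.

```python
def average_drop(full_scores, masked_scores):
    """Average Drop (%) under the Grad-CAM++ definition (assumed here).

    full_scores[i]:   model confidence for the target class on the full input i
    masked_scores[i]: confidence for the same class when only the region
                      highlighted by the saliency map of input i is kept
    """
    if len(full_scores) != len(masked_scores) or not full_scores:
        raise ValueError("need two equal-length, non-empty score lists")
    total = 0.0
    for y, o in zip(full_scores, masked_scores):
        if y <= 0:
            raise ValueError("full-input scores must be positive")
        # Only drops count; an increase in confidence contributes 0.
        total += max(0.0, y - o) / y
    return 100.0 * total / len(full_scores)


# Usage: confidence falls from 0.8 to 0.6 on one input and rises on the other.
print(average_drop([0.8, 0.5], [0.6, 0.7]))
```

Clamping each term at zero means a saliency mask that happens to raise confidence is not rewarded. The metric only penalizes lost evidence.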
Abstract
In this study, we introduce the Multi-Head Explainer (MHEX), a versatile, modular framework that enhances both the explainability and the accuracy of Convolutional Neural Networks (CNNs) and Transformer-based models. MHEX consists of three core components: an Attention Gate that dynamically highlights task-relevant features, Deep Supervision that guides early layers to capture fine-grained details pertinent to the target class, and an Equivalent Matrix that unifies refined local and global representations to generate comprehensive saliency maps. MHEX integrates into existing residual networks such as ResNet and Transformer architectures such as BERT with minimal modifications. Extensive experiments on benchmark datasets in medical imaging and text classification show that MHEX not only improves classification accuracy but also produces highly interpretable and detailed saliency scores.
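For reference, the classic Class Activation Map (CAM) that the CNN branch builds on is a weighted sum of the last convolutional layer's channels, M_c(i, j) = sum_k w_k^c f_k(i, j), where w_k^c is the classifier weight linking channel k to class c. The sketch below is only that baseline, not the Equivalent Matrix, whose construction is not described here. The ReLU clamp is an illustrative stand-in for a non-negativity constraint, and the function name `class_activation_map` is hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def class_activation_map(features: List[Matrix], class_weights: List[float]) -> Matrix:
    """Classic CAM: weighted sum of the last conv layer's channels.

    features[k][i][j]: activation of channel k at spatial position (i, j)
    class_weights[k]:  classifier weight linking channel k to the target class
    """
    if not features or len(features) != len(class_weights):
        raise ValueError("one weight per channel is required")
    h, w = len(features[0]), len(features[0][0])
    cam = [[0.0] * w for _ in range(h)]
    for fmap, wk in zip(features, class_weights):
        for i in range(h):
            for j in range(w):
                cam[i][j] += wk * fmap[i][j]
    # Keep only positive evidence (ReLU), then scale to [0, 1] for display.
    cam = [[max(0.0, v) for v in row] for row in cam]
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


# Usage: two 2x2 channels; the second channel votes against the class.
feats = [
    [[1.0, 0.0], [2.0, 1.0]],
    [[0.0, 1.0], [1.0, 3.0]],
]
print(class_activation_map(feats, [1.0, -0.5]))
```

Negative weights can cancel activations, so without the clamp some positions would carry negative "evidence". Zeroing them keeps the map focused on regions that support the target class.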
