Leveraging counterfactual concepts for debugging and improving CNN model performance
Syed Ali Tariq, Tehseen Zia
TL;DR
The work tackles the challenge of using explanations to boost CNN performance by leveraging counterfactual concepts to identify class-specific important filters. It builds on a counterfactual filter identification model to produce per-class key filters. It then introduces a novel loss, $L_d = L_{CE} - \lambda_1 L_{MC} + \lambda_2 L_{nonMC}$, which is used to retrain the model so that it activates those filters while suppressing irrelevant ones. Through misclassification analysis and targeted retraining, the approach achieves a 1-2% improvement in accuracy on a bird dataset and provides insights into biases and weaknesses in the model's decision-making. The methodology offers a practical pathway to combine explainability with performance gains, with potential impact in domains where both high accuracy and transparency are critical.
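The following Python sketch makes the arithmetic of $L_d$ concrete for a single sample. The exact definitions of the two filter terms are not given here, so it assumes the following. $L_{MC}$ is the mean pooled activation of the class's important (MC) filters. $L_{nonMC}$ is the mean activation of all remaining filters. Under those definitions, subtracting $\lambda_1 L_{MC}$ rewards activating class-relevant filters, and adding $\lambda_2 L_{nonMC}$ penalizes irrelevant ones. The function names, toy inputs and default weights of 0.1 are hypothetical and are not taken from the paper.

```python
import math
from typing import Sequence, Set, Tuple


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Softmax cross-entropy (L_CE) for one sample, via a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def filter_terms(activations: Sequence[float], mc_filters: Set[int]) -> Tuple[float, float]:
    """Return (L_MC, L_nonMC) as mean activations of MC and non-MC filters.

    ASSUMPTION: plain means of pooled filter activations; the paper may
    define these terms differently (e.g. with another normalization).
    """
    mc = [a for i, a in enumerate(activations) if i in mc_filters]
    non_mc = [a for i, a in enumerate(activations) if i not in mc_filters]
    l_mc = sum(mc) / len(mc) if mc else 0.0
    l_non_mc = sum(non_mc) / len(non_mc) if non_mc else 0.0
    return l_mc, l_non_mc


def combined_loss(
    logits: Sequence[float],
    target: int,
    activations: Sequence[float],
    mc_filters: Set[int],
    lambda1: float = 0.1,  # hypothetical weight on L_MC
    lambda2: float = 0.1,  # hypothetical weight on L_nonMC
) -> float:
    """L_d = L_CE - lambda1 * L_MC + lambda2 * L_nonMC for a single sample."""
    l_ce = cross_entropy(logits, target)
    l_mc, l_non_mc = filter_terms(activations, mc_filters)
    return l_ce - lambda1 * l_mc + lambda2 * l_non_mc


if __name__ == "__main__":
    logits = [2.0, 0.5, -1.0]    # hypothetical class scores
    acts = [0.9, 0.8, 0.1, 0.2]  # hypothetical pooled filter activations
    mc = {0, 1}                  # hypothetical MC filters of class 0
    print(combined_loss(logits, 0, acts, mc))
```

In an actual training loop these terms would be computed on batched tensors with an autograd framework. Plain lists are used here only to keep the sketch self-contained.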
Abstract
Counterfactual explanation methods have recently received significant attention for explaining CNN-based image classifiers because they provide easily understandable explanations that align more closely with human reasoning. However, limited attention has been given to using explainability methods to improve model performance. In this paper, we propose to leverage counterfactual concepts to enhance the performance of CNN models in image classification tasks. Our approach uses counterfactual reasoning to identify the crucial filters used in the decision-making process. We then retrain the model using a novel methodology and loss functions that encourage the activation of class-relevant important filters and discourage the activation of irrelevant filters for each class. This process minimizes the deviation between the activation patterns of local predictions and the global activation patterns of their respective inferred classes. By incorporating counterfactual explanations, we validate the model's predictions on unseen data and identify misclassifications. The proposed methodology provides insights into potential weaknesses and biases in the model's learning process, enabling targeted improvements and enhanced performance. Experimental results on publicly available datasets show an improvement of 1-2%, validating the effectiveness of the approach.
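The sketch below illustrates the idea of comparing a prediction's local activation pattern with the global pattern of its inferred class using a simple stand-in. It is not the paper's counterfactual validation procedure. It assumes the following:
- Each class is summarized by a set of important filter indices.
- The top-k most activated filters of one prediction form its local pattern.
- Jaccard similarity serves as the agreement score, and predictions below a threshold are flagged as possible misclassifications.

The function names, k = 3 and the 0.5 threshold are hypothetical choices.

```python
from typing import Dict, Sequence, Set, Tuple


def top_k_filters(activations: Sequence[float], k: int) -> Set[int]:
    """Indices of the k most strongly activated filters (the local pattern)."""
    ranked = sorted(range(len(activations)), key=lambda i: activations[i], reverse=True)
    return set(ranked[:k])


def pattern_agreement(local: Set[int], global_mc: Set[int]) -> float:
    """Jaccard similarity between local and class-level filter sets (1.0 = identical)."""
    union = local | global_mc
    if not union:
        return 1.0
    return len(local & global_mc) / len(union)


def flag_suspect_prediction(
    activations: Sequence[float],
    predicted_class: int,
    class_mc_filters: Dict[int, Set[int]],
    k: int = 3,              # hypothetical size of the local pattern
    threshold: float = 0.5,  # hypothetical agreement cut-off
) -> Tuple[float, bool]:
    """Score a prediction against its inferred class; low agreement marks it as suspect."""
    local = top_k_filters(activations, k)
    agreement = pattern_agreement(local, class_mc_filters[predicted_class])
    return agreement, agreement < threshold


if __name__ == "__main__":
    # Hypothetical per-class important-filter sets for a 6-filter layer.
    class_mc = {0: {0, 1, 2}, 1: {3, 4, 5}}
    acts = [0.9, 0.8, 0.7, 0.1, 0.2, 0.05]
    print(flag_suspect_prediction(acts, 0, class_mc))  # (1.0, False)
    print(flag_suspect_prediction(acts, 1, class_mc))  # (0.0, True)
```

Here a prediction of class 1 is flagged because none of its most active filters belong to that class's important set. That mismatch is the kind of local-versus-global deviation that the retraining aims to reduce.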
