Improving Prototypical Visual Explanations with Reward Reweighing, Reselection, and Retraining
Aaron J. Li, Robin Netzorg, Zhihan Cheng, Zhuoqin Zhang, Bin Yu
TL;DR
This work tackles the interpretability gap in prototype-based image classifiers by introducing R3, a post-processing framework that learns a reward model from human feedback and uses it to guide three offline corrective updates to ProtoPNet variants: Reward Reweighing, Prototype Reselection, and Retraining. The reward model is trained on a compact set of human ratings of image-prototype activations and is used to align prototypes with human preferences. A subsequent retraining step then realigns the base features and the classifier with the updated prototypes to preserve predictive accuracy. Across bird and car datasets, R3 consistently boosts prototype quality (measured by reward and Activation Precision) and often improves or restores predictive accuracy after retraining. The framework also generalizes to ProtoPFormer and to SDFA/SA-augmented variants. It enables customizable, human-aligned interpretability improvements that can be adapted to other prototype-based models, offering practical benefits for trustworthy visual explanations and real-world deployment.
Abstract
In recent years, substantial work has gone into developing deep interpretable methods for image classification that clearly attribute a model's output to specific features of the data. One such method is the Prototypical Part Network (ProtoPNet), which attempts to classify images based on meaningful parts of the input. While this architecture is able to produce visually interpretable classifications, it often learns to classify based on parts of the image that are not semantically meaningful. To address this problem, we propose the Reward Reweighing, Reselection, and Retraining (R3) post-processing framework, which performs three additional corrective updates to a pretrained ProtoPNet in an offline and efficient manner. The first two steps, reward reweighing and prototype reselection, use a reward model learned from collected human feedback to align the prototypes with human preferences. The final step is retraining, which realigns the base features and the classifier layer of the original model with the updated prototypes. We find that our R3 framework consistently improves both the interpretability and the predictive accuracy of ProtoPNet and its variants.
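To make the reselection idea concrete, here is a minimal, hypothetical Python sketch. It assumes that the learned reward model can be queried as a function `reward_fn` returning a scalar score for an image patch, that each prototype slot has a pool of candidate patches, and that a fixed `threshold` decides which prototypes count as low-reward. All of these names and the replacement rule are illustrative assumptions, not the R3 implementation. The sketch covers only the step of swapping low-reward prototypes for higher-reward ones; reweighing and retraining require the full ProtoPNet model and are not shown.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical stand-in: a "patch" is any identifier for an image region.
Patch = str


def reselect_prototypes(
    prototypes: List[Patch],
    candidates: Dict[int, Sequence[Patch]],
    reward_fn: Callable[[Patch], float],  # assumed proxy for the learned reward model
    threshold: float,                     # assumed cutoff for "low-reward" prototypes
) -> List[Patch]:
    """Return a new prototype list in which each prototype scoring below
    `threshold` is replaced by the highest-reward candidate for its slot,
    but only if that candidate scores strictly higher than the current one."""
    updated = list(prototypes)  # copy so the caller's list is not mutated
    for j, proto in enumerate(prototypes):
        current = reward_fn(proto)
        if current >= threshold:
            continue  # prototype already acceptable; keep it
        pool = candidates.get(j, ())
        if not pool:
            continue  # no candidates available for this slot
        best = max(pool, key=reward_fn)
        if reward_fn(best) > current:
            updated[j] = best
    return updated


if __name__ == "__main__":
    # Toy reward scores standing in for human-aligned reward model outputs.
    rewards = {"wing": 0.9, "sky": 0.1, "beak": 0.8, "branch": 0.2, "tail": 0.7}
    protos = ["wing", "sky", "branch"]
    pools = {1: ["beak", "tail"], 2: ["sky"]}
    print(reselect_prototypes(protos, pools, lambda p: rewards[p], threshold=0.5))
```

In the toy run, "sky" is replaced by the higher-reward "beak", while "branch" is kept because its only candidate scores lower. Requiring a strict improvement is a design choice in this sketch that keeps reselection from ever lowering a prototype's reward.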
