Improving Prototypical Visual Explanations with Reward Reweighing, Reselection, and Retraining

Aaron J. Li, Robin Netzorg, Zhihan Cheng, Zhuoqin Zhang, Bin Yu

TL;DR

This work tackles the interpretability gap in prototype-based image classifiers by introducing R3, a post-processing framework that learns a reward model from human feedback to guide three offline prototype updates—Reward Reweighing, Prototype Reselection, and Retraining—for ProtoPNet variants. The reward model is trained from a compact set of human ratings on image-prototype activations and used to align prototypes with human preferences, while a subsequent retraining step realigns the base features and classifier to preserve predictive accuracy. Across bird and car datasets, R3 consistently boosts prototype quality (reward and Activation Precision) and often improves or restores predictive accuracy after retraining, with demonstrated generalizability to ProtoPFormer and SDFA/SA-augmented variants. The framework enables customizable, human-aligned interpretability improvements that can be adapted to other prototype-based models, offering practical benefits for trustworthy visual explanations and deployment in real-world settings.

Abstract

In recent years, work has gone into developing deep interpretable methods for image classification that clearly attribute a model's output to specific features of the data. One such method is the Prototypical Part Network (ProtoPNet), which attempts to classify images based on meaningful parts of the input. While this architecture is able to produce visually interpretable classifications, it often learns to classify based on parts of the image that are not semantically meaningful. To address this problem, we propose the Reward Reweighing, Reselecting, and Retraining (R3) post-processing framework, which performs three additional corrective updates to a pretrained ProtoPNet in an offline and efficient manner. The first two steps involve learning a reward model based on collected human feedback and then aligning the prototypes with human preferences. The final step is retraining, which realigns the base features and the classifier layer of the original model with the updated prototypes. We find that our R3 framework consistently improves both the interpretability and the predictive accuracy of ProtoPNet and its variants.

Paper Structure (29 sections, 2 equations, 5 figures, 10 tables, 1 algorithm)


Figures (5)

  • Figure 1: Rubric used for human feedback on the activation patterns of predictions for birds from the CUB-200-2011 dataset. The rater first estimates a base score from the overlap proportion and may then apply an optional adjustment $\delta \in \{-1, 1\}$ based on how meaningful or characteristic the focused body part is.
  • Figure 2: An overview of the R3-ProtoPNet framework. Dashed arrows indicate our R3 debugging procedure.
  • Figure 3: Trade-off curves between model accuracy and model interpretability. The plot is qualitative.
  • Figure 4: Closest training patches of the five prototypes of ProtoPNet (top row), R2-ProtoPNet (middle row), and R3-ProtoPNet (bottom row) within the same class. Each cluster of three rows of images is a separate class.
  • Figure 5: Prototype projections on the same image (each column) from ProtoPNet (top row), R2-ProtoPNet (middle row), and R3-ProtoPNet (bottom row).
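
The two-stage rubric described in the Figure 1 caption (a base score from overlap proportion, then an optional adjustment $\delta$) can be illustrated with a minimal sketch. The caption does not give the rating scale or the overlap thresholds, so the 1-to-5 scale, the equal-width binning, the clamping, and the function name `rubric_score` below are all assumptions made for illustration, not the authors' rubric.

```python
def rubric_score(overlap: float, delta: int = 0, max_score: int = 5) -> int:
    """Hypothetical rating: base score from overlap proportion plus an optional adjustment.

    overlap   -- fraction of the activation that falls on the object, in [0, 1]
    delta     -- optional adjustment: -1 or 1, with 0 meaning "no adjustment"
    max_score -- top of the rating scale (assumed; not stated in the caption)
    """
    if not 0.0 <= overlap <= 1.0:
        raise ValueError("overlap must be a proportion in [0, 1]")
    if delta not in (-1, 0, 1):
        raise ValueError("delta must be -1, 0 (no adjustment), or 1")

    # Assumed binning: split [0, 1] into max_score equal-width bins,
    # mapping them to base scores 1..max_score.
    base = min(int(overlap * max_score) + 1, max_score)

    # Apply the adjustment and clamp so the rating stays on the scale.
    return max(1, min(max_score, base + delta))


if __name__ == "__main__":
    # Half of the activation overlaps the bird; the focused part is judged characteristic.
    print(rubric_score(0.5, delta=1))
```

Under these assumptions the adjustment can shift a rating by at most one step, so overlap remains the dominant signal and the meaningfulness of the focused part acts as a tie-breaker.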