
Mixture of Gaussian-distributed Prototypes with Generative Modelling for Interpretable and Trustworthy Image Recognition

Chong Wang, Yuanhong Chen, Fengbei Liu, Yuyuan Liu, Davis James McCarthy, Helen Frazer, Gustavo Carneiro

TL;DR

This work tackles the interpretability-trustworthiness gap in prototypical-part image recognition by replacing point-based prototypes with a generative mixture of Gaussian-distributed prototypes, which yields explicit class-conditional densities $p(\mathbf{x}|c)$. Fitting these densities with a memory-bank-augmented EM algorithm and aggregating them with class priors, MGProto achieves accurate classification while detecting OoD inputs through the marginal density $p(\mathbf{x})$, which improves decision reliability. A novel prototype mining strategy inspired by the Tian Ji horse-racing legend engages sub-salient object regions to enrich representations, and a pruning mechanism based on prototype importance priors yields compact models without significant loss in performance. Empirical results on CUB-200-2011, Stanford Cars, Stanford Dogs, and Oxford-IIIT Pets show state-of-the-art accuracy and OoD detection alongside encouraging interpretability results, suggesting practical value for trustworthy visual recognition systems.
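The link between class-conditional densities, class priors, classification and OoD detection can be illustrated with a small sketch. It is a simplification under stated assumptions: prototypes are diagonal-covariance Gaussians, each image is reduced to a single feature vector (MGProto itself works on feature maps $\mathbf{F}$, and its patch-level aggregation is not reproduced here), and the `(pi_k, mu_k, var_k)` tuple layout and function names are hypothetical. An OoD threshold on `log_px` would have to be chosen separately, e.g. on validation data.

```python
import math


def log_gaussian_diag(x, mean, var):
    """Log density of a diagonal-covariance Gaussian N(x; mean, diag(var))."""
    return -0.5 * sum(
        math.log(2 * math.pi * v) + (xi - m) ** 2 / v
        for xi, m, v in zip(x, mean, var)
    )


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def log_class_likelihood(x, prototypes):
    """log p(x|c) = log sum_k pi_k N(x; mu_k, var_k) for one class.

    prototypes: list of (pi_k, mu_k, var_k) tuples (hypothetical layout).
    """
    return logsumexp([
        math.log(pi) + log_gaussian_diag(x, mu, var)
        for pi, mu, var in prototypes
    ])


def predict(x, model, class_prior):
    """Return (posterior p(c|x) per class, log p(x)) via Bayes' theorem."""
    log_joint = {
        c: log_class_likelihood(x, protos) + math.log(class_prior[c])
        for c, protos in model.items()
    }
    # Marginal density: log p(x) = log sum_c p(x|c) p(c); low values flag OoD.
    log_px = logsumexp(list(log_joint.values()))
    posterior = {c: math.exp(lj - log_px) for c, lj in log_joint.items()}
    return posterior, log_px


# Toy model: class "a" has two prototypes, class "b" has one.
model = {
    "a": [(0.5, [0.0, 0.0], [1.0, 1.0]), (0.5, [1.0, 0.0], [1.0, 1.0])],
    "b": [(1.0, [5.0, 5.0], [1.0, 1.0])],
}
prior = {"a": 0.5, "b": 0.5}
post, log_px = predict([0.2, 0.1], model, prior)        # in-distribution input
_, log_px_far = predict([40.0, -40.0], model, prior)   # far from every prototype
```

The posterior alone would still pick a class for the far-away input, which is why the sketch also returns $\log p(\mathbf{x})$: the far input receives a much lower marginal density, and that value, not the softmax-style posterior, carries the OoD signal.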

Abstract

Prototypical-part methods, e.g., ProtoPNet, enhance interpretability in image recognition by linking predictions to training prototypes, thereby offering intuitive insights into their decision-making. Existing methods, which rely on point-based prototype learning, typically face two critical issues: 1) the learned prototypes have limited representation power and are not suitable for detecting Out-of-Distribution (OoD) inputs, reducing their decision trustworthiness; and 2) the necessary projection of the learned prototypes back into the space of training images causes a drastic degradation in predictive performance. Furthermore, current prototype learning adopts an aggressive approach that considers only the most active object parts during training, while overlooking sub-salient object regions that still hold crucial classification information. In this paper, we present a new generative paradigm to learn prototype distributions, termed Mixture of Gaussian-distributed Prototypes (MGProto). The distribution of prototypes from MGProto enables both interpretable image classification and trustworthy recognition of OoD inputs. The optimisation of MGProto naturally projects the learned prototype distributions back into the training image space, thereby addressing the performance degradation caused by prototype projection. Additionally, we develop a novel and effective prototype mining strategy that considers not only the most active but also sub-salient object parts. To promote model compactness, we further propose to prune MGProto by removing prototypes with low importance priors. Experiments on the CUB-200-2011, Stanford Cars, Stanford Dogs, and Oxford-IIIT Pets datasets show that MGProto achieves state-of-the-art image recognition and OoD detection performance, while providing encouraging interpretability results.
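Pruning by importance priors can be pictured as dropping mixture components whose prior (mixture weight) is small, then renormalising the surviving weights so that each class mixture still sums to one. The sketch below is illustrative only: the fixed threshold `min_prior`, the `(pi_k, mu_k, var_k)` tuple layout and the function name are hypothetical choices, not the paper's settings.

```python
def prune_prototypes(prototypes, min_prior):
    """Remove prototypes with pi_k < min_prior and renormalise the rest.

    prototypes: list of (pi_k, mu_k, var_k) tuples for one class
    (hypothetical layout). At least one prototype is always kept.
    """
    kept = [p for p in prototypes if p[0] >= min_prior]
    if not kept:
        # Keep the single most important prototype so the class stays defined.
        kept = [max(prototypes, key=lambda p: p[0])]
    total = sum(pi for pi, _, _ in kept)
    return [(pi / total, mu, var) for pi, mu, var in kept]


protos = [(0.6, [0.0], [1.0]), (0.35, [2.0], [1.0]), (0.05, [9.0], [1.0])]
pruned = prune_prototypes(protos, min_prior=0.1)
print([round(pi, 4) for pi, _, _ in pruned])  # [0.6316, 0.3684]
```

Renormalising keeps $p(\mathbf{x}|c)$ a valid density after pruning; how much accuracy survives depends on how much prior mass the removed prototypes carried.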

Paper Structure

This paper contains 27 sections, 14 equations, 11 figures, 11 tables, and 1 algorithm.

Figures (11)

  • Figure 1: (a) Current prototypical-part networks are softmax-based discriminative classifiers that learn point-based prototypes with limited representation power and struggle to detect OoD inputs. (b) Our method learns a mixture of Gaussian-distributed prototypes through generative modelling, enabling not only interpretable image classification but also trustworthy recognition of OoD samples.
  • Figure 2: Current prototype-based methods (e.g., ProtoPNet chen2019looks and TesNet wang2021interpretable) suffer from drastic performance degradation following the prototype replacement step (denoted by the dotted vertical lines) at each round of the multi-stage training, whereas our MGProto method does not encounter this problem. These curves are obtained from models trained on CUB-200-2011 using a ResNet34 backbone.
  • Figure 3: t-SNE representations of prototypes (stars) and the nearest training patch features (dots), from ProtoPNet (a) and TesNet (b), trained on CUB-200-2011. For clarity, we show 5 random classes (out of 200), where each colour denotes a different class.
  • Figure 4: The overall framework of the MGProto method. For a given image $\mathbf{x}$, the model backbone (e.g., ResNet) extracts initial features $f_{\theta_{\textbf{bcb}}}(\mathbf{x})$ that are then fed to the add-on layer $f_{\theta_{\textbf{add}}}$ to obtain feature maps $\mathbf{F}$. An auxiliary loss $\mathcal{L}_{aux}$ is applied on $f_{\theta_{\textbf{bcb}}}(\mathbf{x})$ to improve the backbone's feature extraction ability. (a) The case-based interpretation is achieved by fitting the feature representation $\mathbf{F}$ into the mixture of Gaussian-distributed prototypes, yielding the class-conditional data probability $p(\mathbf{x}|c)$, which is used to determine whether the input is OoD. Bayes' theorem is then used to derive the posterior class probability $p(c|\mathbf{x})$ for predicting the image category and computing the cross-entropy loss $\mathcal{L}_{ce}$. For simplicity, only 2 prototypes are shown for class $c_1$, so only 2 relevant features from the most active image patches are stored in the memory queue of class $c_1$. (b) For each class, the mixture of Gaussian-distributed prototypes is estimated by a modified EM algorithm that encourages prototype diversity.
  • Figure 5: (a) Illustration of different levels of active patches for prototype mining. For clarity, we assume only two prototypes per class and consider $T=3$ levels of active patches. (b) Diagram of the classic Tian Ji's horse racing legend. (c) Our proposed prototype mining strategy establishes $T-1$ mining competitions (solid lines). The standard classification supervision is represented by the dashed line.
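The per-class memory queue and EM fitting described in Figure 4 can be sketched as follows. This is a minimal illustration under stated assumptions: covariances are diagonal, the queue size and synthetic 2-D features are hypothetical, and the code runs plain EM. It does not reproduce the diversity-encouraging modification that MGProto applies to its EM algorithm, so it should be read as the baseline that the paper modifies, not as the paper's method.

```python
import math
import random
from collections import deque


def log_gaussian_diag(x, mean, var):
    """Log density of a diagonal-covariance Gaussian N(x; mean, diag(var))."""
    return -0.5 * sum(math.log(2 * math.pi * v) + (xi - m) ** 2 / v
                      for xi, m, v in zip(x, mean, var))


def em_step(features, prototypes, var_floor=1e-3):
    """One standard EM iteration for a diagonal-covariance GMM.

    features: list of D-dim vectors from one class's memory queue.
    prototypes: list of (pi_k, mu_k, var_k) tuples (hypothetical layout).
    No diversity-promoting term is included.
    """
    dim = len(features[0])
    # E-step: responsibility of each prototype for each feature (stable softmax).
    resp = []
    for x in features:
        logs = [math.log(pi) + log_gaussian_diag(x, mu, var)
                for pi, mu, var in prototypes]
        top = max(logs)
        w = [math.exp(v - top) for v in logs]
        total = sum(w)
        resp.append([wi / total for wi in w])
    # M-step: re-estimate each prototype's prior, mean and diagonal variance.
    updated = []
    for k in range(len(prototypes)):
        nk = max(sum(r[k] for r in resp), 1e-12)  # effective count, floored
        mu = [sum(r[k] * x[d] for r, x in zip(resp, features)) / nk
              for d in range(dim)]
        var = [max(var_floor,
                   sum(r[k] * (x[d] - mu[d]) ** 2
                       for r, x in zip(resp, features)) / nk)
               for d in range(dim)]
        updated.append((nk, mu, var))
    z = sum(nk for nk, _, _ in updated)  # renormalise priors to sum to one
    return [(nk / z, mu, var) for nk, mu, var in updated]


# FIFO memory queue for one class: once full, the oldest features are evicted.
queue = deque(maxlen=200)  # queue size is a hypothetical choice
random.seed(0)
for _ in range(100):
    # Synthetic patch features drawn around two "object parts".
    queue.append([random.gauss(0.0, 0.3), random.gauss(0.0, 0.3)])
    queue.append([random.gauss(3.0, 0.3), random.gauss(3.0, 0.3)])

protos = [(0.5, [0.5, 0.5], [1.0, 1.0]), (0.5, [2.5, 2.5], [1.0, 1.0])]
for _ in range(20):
    protos = em_step(list(queue), protos)
```

Fitting on the queue rather than on a single mini-batch lets each class's mixture see many more features per update; the queue length trades off sample size against how stale the stored features are.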