Learning with Mixture of Prototypes for Out-of-Distribution Detection

Haodong Lu, Dong Gong, Shuo Wang, Jason Xue, Lina Yao, Kristen Moore

TL;DR

This paper addresses OOD detection by challenging the common assumption that each ID class can be represented by a single prototype. It introduces PALM, a framework that learns a mixture of class-specific prototypes in a hyperspherical embedding space, using reciprocal-neighbor soft assignments to dynamically allocate each sample to multiple prototypes. The approach combines a maximum likelihood loss with a prototype-level contrastive loss and updates the prototypes via an exponential moving average (EMA). This enables accurate ID-OOD discrimination and extends naturally to unsupervised OOD detection with unlabelled ID data. Empirically, PALM achieves a state-of-the-art average AUROC of 93.82 on the CIFAR-100 benchmark, notable reductions in false positive rate (FPR), and strong performance in large-scale and unsupervised settings, with modest computational overhead. These results suggest that PALM represents data diversity more faithfully and detects OOD samples more robustly in open-world scenarios.
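
To make the two prototype-level components concrete (the prototype contrastive loss and the EMA update), here is a minimal NumPy sketch. It rests on stated assumptions rather than the paper's equations: the contrastive loss is written in a SupCon-style form in which prototypes of the same class are positives, and the EMA pulls each prototype toward the assignment-weighted sum of its embeddings before re-projecting it onto the unit hypersphere. The helper names (`l2_normalize`, `prototype_contrastive_loss`, `ema_update`) and the defaults `tau=0.1` and `alpha=0.99` are hypothetical.

```python
import numpy as np

def l2_normalize(x, eps=1e-12):
    """Project row vectors onto the unit hypersphere."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)

def prototype_contrastive_loss(protos, proto_labels, tau=0.1):
    """SupCon-style loss over prototypes (assumed form, not the paper's exact equation).

    protos: (P, d) unit-norm prototypes; proto_labels: (P,) class id of each prototype.
    Same-class prototypes are positives; all other prototypes enter the denominator.
    """
    P = protos.shape[0]
    self_mask = np.eye(P, dtype=bool)
    sim = np.where(self_mask, -np.inf, protos @ protos.T / tau)   # drop self-pairs
    row_max = sim.max(axis=1, keepdims=True)
    log_den = row_max + np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    log_prob = sim - log_den                                        # row-wise log-softmax
    pos = (proto_labels[:, None] == proto_labels[None, :]) & ~self_mask
    n_pos = pos.sum(axis=1)
    has_pos = n_pos > 0                                             # anchors with >= 1 positive
    mean_pos = np.where(pos, log_prob, 0.0).sum(axis=1)[has_pos] / n_pos[has_pos]
    return float(-mean_pos.mean())

def ema_update(protos, z, weights, alpha=0.99):
    """EMA prototype update, then re-projection onto the hypersphere (assumed form).

    z: (N, d) unit-norm embeddings; weights: (N, P) soft assignment weights,
    zero for prototypes outside each sample's class.
    """
    pulled = weights.T @ z                  # (P, d) assignment-weighted sum of embeddings
    return l2_normalize(alpha * protos + (1.0 - alpha) * pulled)
```

Re-normalizing after every update keeps the prototypes on the same hypersphere as the embeddings, so dot products stay cosine similarities; prototypes that receive no assignment weight in a batch are left unchanged.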

Abstract

Out-of-distribution (OOD) detection aims to detect test samples that lie far from the in-distribution (ID) training data, which is crucial for the safe deployment of machine learning models in the real world. Distance-based OOD detection methods have emerged alongside advances in deep representation learning. They identify unseen OOD samples by measuring their distances from ID class centroids or prototypes. However, existing approaches learn representations under oversimplified data assumptions, e.g., modeling the ID data of each class with a single centroid prototype, or using loss functions not designed for OOD detection, and thereby overlook the natural diversity within the data. Naively forcing the samples of each class to be compact around only one prototype leads to inadequate modeling of realistic data and limits performance. To tackle these issues, we propose PrototypicAl Learning with a Mixture of prototypes (PALM), which models each class with multiple prototypes to capture sample diversity and learns more faithful and compact sample embeddings to enhance OOD detection. Our method automatically identifies and dynamically updates prototypes, assigning each sample to a subset of prototypes via reciprocal neighbor soft assignment weights. PALM optimizes a maximum likelihood estimation (MLE) loss that encourages sample embeddings to be compact around their associated prototypes, together with a contrastive loss on all prototypes that enhances intra-class compactness and inter-class discrimination at the prototype level. Moreover, because the prototypes are estimated automatically, the approach extends to the challenging task of OOD detection with unlabelled ID data. Extensive experiments demonstrate the superiority of PALM, which achieves a state-of-the-art average AUROC of 93.82 on the challenging CIFAR-100 benchmark. Code is available at https://github.com/jeff024/PALM.
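
The per-sample side of the method, soft assignment of each embedding to a subset of its own class's prototypes followed by an MLE loss, can be sketched as below. This is an illustrative approximation, not the authors' implementation: top-k pruning of softmax weights stands in for the reciprocal-neighbor criterion, the likelihood is assumed to take a mixture form (assignment-weighted similarity to the sample's own prototypes divided by similarity to all prototypes), and the function names `soft_assign` and `mle_loss`, as well as the temperatures `tau_a` and `tau`, are hypothetical.

```python
import numpy as np

def soft_assign(z, y, protos, proto_labels, topk=2, tau_a=0.1):
    """Soft-assign each sample to its top-k prototypes within its own class.

    z: (N, d) unit-norm embeddings; y: (N,) labels; protos: (P, d); proto_labels: (P,).
    Returns (N, P) weights whose rows sum to 1 and are zero outside the sample's class.
    Top-k pruning is a stand-in for the reciprocal-neighbor criterion (assumption).
    """
    same = y[:, None] == proto_labels[None, :]
    sim = np.where(same, z @ protos.T / tau_a, -np.inf)        # mask other classes
    w = np.exp(sim - sim.max(axis=1, keepdims=True))           # unnormalized softmax
    k = min(topk, w.shape[1])
    kth = -np.sort(-w, axis=1)[:, k - 1:k]                     # k-th largest weight per row
    w = np.where(w >= kth, w, 0.0)                             # prune to the top-k
    return w / w.sum(axis=1, keepdims=True)

def mle_loss(z, w, protos, tau=0.1):
    """Negative log-likelihood of a prototype mixture (assumed vMF-style form)."""
    s = z @ protos.T / tau                                     # (N, P) scaled cosine sims
    e = np.exp(s - s.max(axis=1, keepdims=True))               # shift cancels in the ratio
    log_num = np.log(np.sum(w * e, axis=1))                    # own-class, weighted
    log_den = np.log(np.sum(e, axis=1))                        # all prototypes
    return float(np.mean(log_den - log_num))
```

Because the weights are zero outside a sample's own class, the numerator only rewards closeness to that class's assigned prototypes, while the denominator sums over all prototypes, so similarity to other classes' prototypes raises the loss.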

Paper Structure (21 sections, 15 equations, 9 figures, 12 tables, 1 algorithm)

Figures (9)

  • Figure 1: Overview of our proposed framework of prototypical learning with a mixture of prototypes (PALM). We regularize the embedding space via the proposed mixture-of-prototypes modeling. We propose (1) optimizing an MLE loss that encourages the sample embeddings to be compact around their associated prototypes, and (2) minimizing a prototype contrastive loss that regularizes the model at the prototype level. The calculation of the assignment weights is visualized on the right.
  • Figure 2: Analysis of embedding quality for CIFAR-100 (ID) using CIDER and PALM. We examine the distance distribution (top), and evaluate compactness and the proportion of far ID samples (bottom).
  • Figure 3: UMAP [umap] visualization of the first 20 subclasses of ID (CIFAR-100) samples and all OOD (iSUN) samples, plotted in the same embedding space, for (a) SSD+ [2021ssd], (b) KNN+ [sun2022out], (c) CIDER [ming2023exploit], and (d) PALM. The scores are obtained by scaling the distance metric used by each method to $[0,1]$ for visualization. We measure the area of the overlapping sections between the ID and OOD scores, as shown in Fig. 5.
  • Figure 4: Ablation studies on (a) pruning selection, (b) soft vs. hard assignments, (c) the number of prototypes per class, and (d) the prototype update procedure.
  • Figure 5: Area of the overlapping sections between ID and OOD distance densities, in percent. Smaller numbers indicate better ID-OOD separation.
  • ...and 4 more figures
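
The overlap measure behind Figures 3 and 5 can be approximated as follows. The sketch assumes that ID and OOD scores are min-max scaled jointly to [0, 1] and that each density is estimated with a histogram on shared bins; the paper may use a different estimator (for example, a kernel density estimate). The overlap is the integral of the pointwise minimum of the two densities, reported in percent. The function names `minmax_scale` and `overlap_area_percent` and the default of 100 bins are hypothetical.

```python
import numpy as np

def minmax_scale(id_scores, ood_scores):
    """Jointly rescale ID and OOD scores to [0, 1] using a shared min and max (assumed)."""
    all_s = np.concatenate([id_scores, ood_scores])
    lo, hi = all_s.min(), all_s.max()
    span = hi - lo if hi > lo else 1.0                  # avoid division by zero
    return (id_scores - lo) / span, (ood_scores - lo) / span

def overlap_area_percent(id_scores, ood_scores, bins=100):
    """Overlap of the ID and OOD score densities, in percent (histogram estimate)."""
    a, b = minmax_scale(np.asarray(id_scores, float), np.asarray(ood_scores, float))
    edges = np.linspace(0.0, 1.0, bins + 1)             # shared bins over [0, 1]
    p_id, _ = np.histogram(a, bins=edges, density=True)
    p_ood, _ = np.histogram(b, bins=edges, density=True)
    # integral of min(p_id, p_ood) approximated bin by bin
    return 100.0 * float(np.sum(np.minimum(p_id, p_ood) * np.diff(edges)))
```

Identical score distributions give 100% and perfectly separated ones give 0%, which matches the reading of Figure 5 that smaller is better.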