Learning with Mixture of Prototypes for Out-of-Distribution Detection
Haodong Lu, Dong Gong, Shuo Wang, Jason Xue, Lina Yao, Kristen Moore
TL;DR
This paper addresses out-of-distribution (OOD) detection by challenging the common assumption that each in-distribution (ID) class can be represented by a single prototype. It introduces PALM, a framework that learns a mixture of class-specific prototypes in a hyperspherical embedding space and uses reciprocal-neighbor soft assignments to dynamically allocate each sample to multiple prototypes. The approach combines a maximum likelihood loss with a prototype-level contrastive loss and updates prototypes via an exponential moving average (EMA), which sharpens ID-OOD discrimination and extends naturally to unsupervised OOD detection. Empirically, PALM achieves a state-of-the-art average AUROC of 93.82 on CIFAR-100, improves the false positive rate (FPR), and performs strongly in large-scale and unsupervised settings with modest computational overhead. These results suggest that PALM represents data diversity more faithfully and detects OOD samples more robustly in open-world scenarios.
Abstract
Out-of-distribution (OOD) detection aims to detect testing samples far away from the in-distribution (ID) training data, which is crucial for the safe deployment of machine learning models in the real world. Distance-based OOD detection methods have emerged with enhanced deep representation learning. They identify unseen OOD samples by measuring their distances from ID class centroids or prototypes. However, existing approaches learn the representation relying on oversimplified data assumptions, e.g., modeling the ID data of each class with a single centroid class prototype or using loss functions not designed for OOD detection, which overlook the natural diversity within the data. Naively forcing the samples of each class to be compact around only one prototype leads to inadequate modeling of realistic data and limited performance. To tackle these issues, we propose PrototypicAl Learning with a Mixture of prototypes (PALM), which models each class with multiple prototypes to capture sample diversity and learns more faithful and compact sample embeddings to enhance OOD detection. Our method automatically identifies and dynamically updates prototypes, assigning each sample to a subset of prototypes via reciprocal neighbor soft assignment weights. PALM optimizes a maximum likelihood estimation (MLE) loss to encourage the sample embeddings to be compact around the associated prototypes, as well as a contrastive loss on all prototypes to enhance intra-class compactness and inter-class discrimination at the prototype level. Moreover, the automatic estimation of prototypes enables our approach to be extended to the challenging OOD detection task with unlabelled ID data. Extensive experiments demonstrate the superiority of PALM, achieving state-of-the-art average AUROC performance of 93.82 on the challenging CIFAR-100 benchmark. Code is available at https://github.com/jeff024/PALM.
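
To make the idea of several prototypes per class on a hypersphere concrete, the following is a minimal sketch in plain Python, not the authors' implementation. It shows two pieces: a momentum (EMA-style) prototype update followed by renormalization, and a nearest-prototype cosine score for separating ID from OOD samples. The function names, the momentum value `alpha`, the toy 2-D prototypes, and the cosine scoring rule are all assumptions made for illustration; the abstract does not specify the scoring function or the hyperparameters, and the reciprocal neighbor assignment is replaced here by a hand-supplied `weight`.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm so it lies on the unit hypersphere."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine(u, v):
    """Dot product of two unit vectors, which equals their cosine similarity."""
    return sum(a * b for a, b in zip(u, v))


def ema_update(prototype, embedding, weight, alpha=0.9):
    """Move one prototype toward an assigned embedding, then renormalize.

    `weight` stands in for a soft assignment weight (supplied by hand here),
    and `alpha` is an assumed momentum value, not one taken from the paper.
    """
    mixed = [alpha * p + (1.0 - alpha) * weight * z
             for p, z in zip(prototype, embedding)]
    return normalize(mixed)


def id_score(embedding, prototypes_by_class):
    """Highest cosine similarity to any prototype of any class.

    A low score suggests the sample is far from every ID prototype.
    This scoring rule is an illustrative choice, not the paper's.
    """
    z = normalize(embedding)
    return max(cosine(z, p)
               for protos in prototypes_by_class.values()
               for p in protos)


# Hypothetical toy setup: class "a" has two prototypes, class "b" has one.
prototypes = {
    "a": [normalize([1.0, 0.0]), normalize([0.8, 0.6])],
    "b": [normalize([0.0, 1.0])],
}

near = id_score([0.9, 0.5], prototypes)    # lies close to a prototype of "a"
far = id_score([-1.0, -1.0], prototypes)   # points away from all prototypes

# Pull the second prototype of class "a" toward a new sample embedding.
updated = ema_update(prototypes["a"][1], normalize([0.9, 0.5]), weight=1.0)
```

In this toy setup `near` exceeds `far`, so a threshold on the score would flag the second sample as OOD. With a single prototype per class, a sample near `[0.8, 0.6]` would score lower against class "a", which is the kind of within-class diversity the mixture is meant to capture.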
