
Diverse Prototypical Ensembles Improve Robustness to Subpopulation Shift

Minh Nguyen Nhat To, Paul F. R. Wilson, Viet Nguyen, Mohamed Harmanani, Michael Cooper, Fahimeh Fooladgar, Purang Abolmaesumi, Parvin Mousavi, Rahul G. Krishnan

TL;DR

This paper tackles subpopulation shift by introducing Diverse Prototypical Ensembles (DPE), a two-stage method that keeps a fixed feature extractor and replaces the classifier with a diversified, distance-based prototypical ensemble. By jointly training multiple prototypes per class with explicit inter-prototype similarity regularization and bootstrap diversification, DPE discovers and covers diverse subpopulations without requiring subgroup annotations. Empirical results across nine real-world benchmarks show that DPE improves worst-group accuracy, often surpassing state-of-the-art methods, while maintaining competitive standard accuracy. The approach offers a scalable, annotation-free path to fairness in deployment settings where subgroup labels are unavailable or costly to obtain.
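The bootstrap diversification step can be illustrated with a short sketch. It assumes that each ensemble member is trained on its own class-balanced resample of the training set, drawn with replacement. The function name `balanced_bootstrap` and the `per_class` parameter are hypothetical and are not taken from the DPE code release.

```python
import random
from collections import defaultdict


def balanced_bootstrap(labels, num_members, per_class, seed=0):
    """Draw one class-balanced bootstrap sample of indices per ensemble member.

    Hypothetical sketch: each member receives `per_class` indices from every
    class, sampled with replacement, so members see different data subsets.
    """
    rng = random.Random(seed)
    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    samples = []
    for _ in range(num_members):
        member = []
        for y in sorted(by_class):
            # Sampling with replacement makes this a bootstrap resample.
            member.extend(rng.choices(by_class[y], k=per_class))
        rng.shuffle(member)
        samples.append(member)
    return samples
```

With `labels = [0] * 90 + [1] * 10`, calling `balanced_bootstrap(labels, num_members=3, per_class=8)` gives each of the three members 16 indices, 8 from each class, even though class 1 is rare. Balancing per class is one plausible choice; the paper's actual sampling scheme may differ.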

Abstract

Subpopulation shift, characterized by a disparity in subpopulation distributions between the training and target datasets, can significantly degrade the performance of machine learning models. Current solutions to subpopulation shift modify empirical risk minimization with re-weighting strategies to improve generalization. This strategy relies on assumptions about the number and nature of subpopulations and on annotations of group membership, which are unavailable for many real-world datasets. Instead, we propose using an ensemble of diverse classifiers to adaptively capture the risk associated with subpopulations. Given a feature extractor network, we replace its standard linear classification layer with a mixture of prototypical classifiers, where each member is trained to classify the data while focusing on features and samples different from those of the other members. In an empirical evaluation on nine real-world datasets covering diverse domains and kinds of subpopulation shift, our method, Diverse Prototypical Ensembles (DPEs), often outperforms the prior state of the art in worst-group accuracy. The code is available at https://github.com/minhto2802/dpe4subpop
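A minimal sketch of a distance-based prototypical ensemble head, assuming the frozen features `z` have already been computed. Each member holds one prototype per class and scores classes by negative squared Euclidean distance divided by a temperature; the ensemble then averages member probabilities. The distance, the temperature, and averaging as the combination rule are assumptions for illustration, and all function names are hypothetical.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def squared_distance(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def member_probs(z, prototypes, temperature=1.0):
    """Class probabilities from one member; prototypes[k] is its class-k prototype."""
    # Assumed scoring: a closer prototype gives a larger logit.
    logits = [-squared_distance(z, p) / temperature for p in prototypes]
    return softmax(logits)


def ensemble_probs(z, ensemble, temperature=1.0):
    """Average member probabilities; ensemble[m][k] is member m's class-k prototype."""
    per_member = [member_probs(z, protos, temperature) for protos in ensemble]
    num_classes = len(per_member[0])
    return [sum(p[k] for p in per_member) / len(per_member) for k in range(num_classes)]
```

With two members and two classes, `ensemble_probs([0.5, 0.2], [[[0, 0], [4, 0]], [[0, 3], [4, 3]]])` puts almost all probability on class 0, because both members' class-0 prototypes lie closer to the feature vector than their class-1 prototypes.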

Paper Structure

This paper contains 46 sections, 10 equations, 15 figures, 18 tables, and 1 algorithm.

Figures (15)

  • Figure 1: High-level overview of our method. (1) Binary classification with implicit (unannotated) subgroups. We aim to natively detect and correct for subpopulation shifts without prior subgroup knowledge. (2) Given a frozen feature extractor, $f(\cdot)$, we train (3) an ensemble of $N$ prototype classifiers for each of the $K$ classes to identify distinct subgroups. These classifiers are trained using $\mathcal{L}_{\text{IPS}}$ (the inter-prototype similarity loss) to maximize prototype diversity, ensuring robust subpopulation capture within each class. (4) A low-dimensional projection of the centroids and proximal images for the class "Landbird" in Waterbirds. The learned centroids for each ensemble member reveal distinct latent subpopulations. Points closest to each centroid appear in blue, while points farther away are in red. The closest few points are shown in dark blue, with the corresponding images visualized in (5). (5) Visualization confirms DPE's ability to capture salient subgroups. We have manually annotated the theme associated with each learned prototype centroid. The closest points to each centroid exhibit thematic consistency, aligning with implicit data subgroups (e.g., birds "on land" vs. "in water").
  • Figure 1: Class/Attribute distribution in Waterbirds dataset.
  • Figure 2: Motivation of DPE. (a) The synthetic training data consists of two classes, with majority subgroups containing Attribute 1 and minority subgroups containing Attributes 2 and 3. Training a single model on the entire dataset leads to suboptimal decision boundaries that focus primarily on the majority subgroups; (b, c, d) as the number of models in the prototypical ensemble increases, with each member trained to classify based on a distinct attribute, the decision boundaries become more refined, improving classification across subpopulations.
  • Figure 2: Class/Attribute distribution in CelebA dataset.
  • Figure 3: Worst-group improvement over ERM* when using DPE with and without subgroup annotations.
  • ...and 10 more figures
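Figure 1 names an inter-prototype similarity loss, $\mathcal{L}_{\text{IPS}}$, used to keep prototypes diverse. Its exact equation is not reproduced in this section, so the sketch below assumes one simple form: the mean pairwise cosine similarity between the same-class prototypes of different ensemble members, which training would minimize alongside the classification loss. The name `ips_penalty` and the choice of cosine similarity are hypothetical.

```python
import math
from itertools import combinations


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def ips_penalty(ensemble):
    """Hypothetical L_IPS: mean cosine similarity between same-class prototypes
    of every pair of ensemble members; ensemble[m][k] is member m's class-k
    prototype. Lower values mean more diverse prototypes."""
    num_members = len(ensemble)
    if num_members < 2:
        raise ValueError("need at least two ensemble members to compare prototypes")
    num_classes = len(ensemble[0])
    total, count = 0.0, 0
    for k in range(num_classes):
        # Compare class-k prototypes across every unordered pair of members.
        for i, j in combinations(range(num_members), 2):
            total += cosine_similarity(ensemble[i][k], ensemble[j][k])
            count += 1
    return total / count
```

Identical members give a penalty of 1.0, and members whose same-class prototypes are orthogonal give 0.0. Cosine similarity ignores prototype norms, so this form only rewards pointing prototypes in different directions; a distance-based penalty would be an equally plausible reading of the caption.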