A Tale of Two Classes: Adapting Supervised Contrastive Learning to Binary Imbalanced Datasets

David Mildenberger, Paul Hager, Daniel Rueckert, Martin J Menten

TL;DR

This work demonstrates that SupCon, while powerful for balanced multi-class problems, suffers severe representation collapse and reduced downstream utility on binary imbalanced datasets. It introduces two targeted fixes, Supervised Minority and Supervised Prototypes, as well as two diagnostic metrics, SAA and CAC, that detect and quantify the collapse where canonical metrics fail to reveal it. The fixes yield up to a 35% gain in downstream accuracy over SupCon and outperform leading long-tailed methods by up to 5% on several medical and natural imaging tasks, with minimal computational overhead. Together with theoretical insights and extensive ablations, the paper provides practical strategies and diagnostics for extending SupCon to prevalent binary-imbalance scenarios, including medical imaging.
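To make the Supervised Minority idea concrete, here is a minimal pure-Python sketch of the standard SupCon loss with one hypothetical switch, `supervised_classes`. It assumes our own reading of the Figure 3 description. Anchors from a supervised class (the minority) treat every same-class sample as a positive. All other anchors fall back to SimCLR-style positives, meaning other augmented views of the same image, identified by a hypothetical `instance_ids` list. The function names, arguments and this interpretation are assumptions, not the authors' code.

```python
import math


def _unit(v):
    """L2-normalise a vector so embeddings lie on the unit hypersphere."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def supcon_loss(embeddings, labels, instance_ids, temperature=0.1, supervised_classes=None):
    """SupCon loss averaged over anchors (hypothetical sketch, not the paper's code).

    supervised_classes=None gives standard SupCon (every class is supervised).
    supervised_classes={minority_label} gives an assumed Supervised-Minority-style
    variant: only minority anchors use label-based positives.
    """
    z = [_unit(e) for e in embeddings]
    n = len(z)
    total, anchors = 0.0, 0
    for i in range(n):
        supervised = supervised_classes is None or labels[i] in supervised_classes
        # P(i): same-class samples for supervised anchors, otherwise other views of the same image
        positives = [
            j for j in range(n)
            if j != i and (labels[j] == labels[i] if supervised
                           else instance_ids[j] == instance_ids[i])
        ]
        if not positives:
            continue  # an anchor without positives contributes nothing
        logits = [sum(a * b for a, b in zip(z[i], z[j])) / temperature for j in range(n)]
        # log of the denominator: stable log-sum-exp over A(i), i.e. every sample except the anchor
        others = [logits[j] for j in range(n) if j != i]
        peak = max(others)
        log_denom = peak + math.log(sum(math.exp(s - peak) for s in others))
        total += sum(log_denom - logits[p] for p in positives) / len(positives)
        anchors += 1
    if anchors == 0:
        raise ValueError("no anchor has a positive pair")
    return total / anchors


labels = [1, 1, 0, 0, 0, 0]  # 1 = minority, 0 = majority
views = [0, 0, 1, 1, 2, 2]   # two augmented views per image
spread = [[1, 0], [1, 0], [0, 1], [0, 1], [0, -1], [0, -1]]
print(supcon_loss(spread, labels, views))                          # about 8.89: majority is pulled together
print(supcon_loss(spread, labels, views, supervised_classes={1}))  # about 0.0001: majority may stay spread
```

The usage lines show the intended effect under this reading. Full SupCon heavily penalises a majority class spread over two directions. The minority-only variant does not, which removes the pressure that collapses the dominant class.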

Abstract

Supervised contrastive learning (SupCon) has proven to be a powerful alternative to the standard cross-entropy loss for classification of multi-class balanced datasets. However, it struggles to learn well-conditioned representations of datasets with long-tailed class distributions. This problem is potentially exacerbated for binary imbalanced distributions, which are common in real-world problems such as medical diagnosis. In experiments on seven binary datasets of natural and medical images, we show that the performance of SupCon decreases with increasing class imbalance. To substantiate these findings, we introduce two novel metrics that evaluate the quality of the learned representation space. By measuring the class distribution in local neighborhoods, we are able to uncover structural deficiencies of the representation space that classical metrics cannot detect. Informed by these insights, we propose two new supervised contrastive learning strategies tailored to binary imbalanced datasets that improve the structure of the representation space and increase downstream classification accuracy over standard SupCon by up to 35%. We make our code available.
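The idea of measuring the class distribution in local neighborhoods can be illustrated with a simple k-nearest-neighbor label-agreement score, reported per class. This is a hypothetical stand-in written for illustration only. It is not the paper's SAA or CAC definition, which this summary does not give, and `knn_label_agreement` is an invented name. Reporting the score per class matters: in a collapsed space the majority class can still look well clustered while minority samples have almost no same-class neighbors.

```python
import math


def knn_label_agreement(embeddings, labels, k=5):
    """Per-class mean fraction of each sample's k nearest neighbours (by cosine
    similarity) that share its label. Hypothetical illustration, not SAA/CAC."""
    def unit(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    z = [unit(e) for e in embeddings]
    n = len(z)
    if not 0 < k < n:
        raise ValueError("k must be between 1 and n - 1")
    per_class = {}
    for i in range(n):
        # rank all other samples by cosine similarity to sample i (ties broken by index)
        sims = sorted(
            ((sum(a * b for a, b in zip(z[i], z[j])), j) for j in range(n) if j != i),
            reverse=True,
        )
        neighbours = [j for _, j in sims[:k]]
        agreement = sum(labels[j] == labels[i] for j in neighbours) / k
        per_class.setdefault(labels[i], []).append(agreement)
    # average within each class so the minority is not swamped by the majority
    return {c: sum(v) / len(v) for c, v in per_class.items()}


labs = [1, 1, 0, 0, 0, 0]
# minority samples sit among majority samples in the same narrow cone
mixed = [[1, 0.0], [1, 0.3], [1, 0.01], [1, 0.29], [1, -0.2], [1, -0.3]]
print(knn_label_agreement(mixed, labs, k=1))
```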

Paper Structure

This paper contains 67 sections, 1 theorem, 28 equations, 20 figures, 7 tables.

Key Result

Lemma 1

Let $z_i, z_k \in \mathcal{S}^{128}$ be two projections of an uninitialized ResNet50 model and an uninitialized projection layer. Then, for a small $\varepsilon > 0$, $\| z_i - z_k \| \leq \varepsilon$.
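One consequence of the lemma follows under an assumption that is standard for SupCon but not stated here: the loss scores a pair by the dot product $z_i^\top z_a$ divided by a temperature $\tau$, with $P(i)$ the positives and $A(i)$ all contrasted samples of anchor $i$. For unit vectors, the distance bound becomes a similarity bound. All logits are then nearly equal at initialization, so each term of the loss sits close to $\log|A(i)|$ regardless of the labels:

$$\|z_i - z_k\|^2 = 2 - 2\,z_i^\top z_k \quad\Longrightarrow\quad 1 - \tfrac{\varepsilon^2}{2} \le z_i^\top z_k \le 1$$

$$s_{ia} = \frac{z_i^\top z_a}{\tau} \in \left[\frac{1 - \varepsilon^2/2}{\tau},\ \frac{1}{\tau}\right] \quad\Longrightarrow\quad \left|\, -\log\frac{\exp(s_{ip})}{\sum_{a \in A(i)} \exp(s_{ia})} - \log|A(i)| \,\right| \le \frac{\varepsilon^2}{2\tau} \quad \text{for all } p \in P(i)$$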

Figures (20)

  • Figure 1: Supervised contrastive learning (SupCon) on multi-class balanced datasets returns a well-conditioned representation space, in which semantic classes are clearly separated. We show that for binary imbalanced datasets the prevalence of a dominant majority class causes the embeddings to collapse to a single point. Our proposed fixes restore the clear separation of semantic classes.
  • Figure 2: Our novel sample alignment accuracy (SAA) and class alignment consistency (CAC) metrics capture the relationships between embeddings of different classes instead of only within one class. Because they measure the separability of latent classes more directly, they are stronger indicators of a representation space's downstream utility.
  • Figure 3: We introduce two fixes for supervised contrastive learning. Supervised Minority applies supervision exclusively to the minority class, preventing class collapse and enhancing alignment of the minority class. Supervised Prototypes attracts samples to fixed class prototypes, improving both class alignment and uniformity.
  • Figure 4: Boxplots of metrics analysing SupCon's representation space learned from the plants dataset. As class imbalance grows, the representation space collapses even though the canonical SAD and CAD metrics remain low. In contrast, SAA and CAC correctly identify the collapse. Similar results are observed on the insects and animals datasets (see supplementary material).
  • Figure 5: Correlations between five representation-space metrics and linear probing performance across all datasets and all considered methods. The overall $R^2$ is calculated globally over all points, while the dataset mean $R^2$ is calculated per dataset and then averaged. As SAA and CAC are the only metrics that account for relationships between samples and classes instead of simply within them, they correlate much more strongly with downstream performance.
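The two summary statistics from the Figure 5 caption can be sketched in a few lines. This sketch makes two assumptions. First, each point is a (dataset, metric value, linear-probe accuracy) triple. Second, $R^2$ comes from an ordinary least-squares line fit of accuracy on the metric, which for one predictor equals the squared Pearson correlation. The caption does not state the regression used, and both function names are hypothetical.

```python
from collections import defaultdict


def r_squared(xs, ys):
    """R^2 of a least-squares line y = a + b*x (squared Pearson correlation)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    if sxx == 0 or syy == 0:
        raise ValueError("R^2 is undefined for constant inputs")
    return sxy * sxy / (sxx * syy)


def overall_and_dataset_mean_r2(points):
    """points: iterable of (dataset_name, metric_value, probe_accuracy) triples."""
    points = list(points)
    # overall R^2: one fit over every point, ignoring dataset membership
    overall = r_squared([m for _, m, _ in points], [a for _, _, a in points])
    # dataset mean R^2: fit each dataset separately, then average the scores
    by_dataset = defaultdict(lambda: ([], []))
    for name, metric, acc in points:
        by_dataset[name][0].append(metric)
        by_dataset[name][1].append(acc)
    per_dataset = [r_squared(xs, ys) for xs, ys in by_dataset.values()]
    return overall, sum(per_dataset) / len(per_dataset)


# two datasets, each perfectly linear, but offset from one another
pts = [("A", 0, 0), ("A", 1, 1), ("A", 2, 2), ("B", 0, 10), ("B", 1, 11), ("B", 2, 12)]
print(overall_and_dataset_mean_r2(pts))
```

The toy example shows why both numbers are worth reporting. Each dataset is perfectly linear, so the dataset mean $R^2$ is 1. Pooling the points gives an overall $R^2$ of only about 0.03, because the datasets sit at different accuracy levels.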

Theorems & Definitions (1)

  • Lemma 1