A Theoretical Framework for Preventing Class Collapse in Supervised Contrastive Learning
Chungpa Lee, Jeongheon Oh, Kibok Lee, Jy-yong Sohn
TL;DR
This work tackles class collapse in supervised contrastive learning (SupCL) by introducing the Simplex-to-Simplex Embedding Model (SSEM), a geometric framework that characterizes all optimal embeddings under a convex combination of supervised and self-supervised losses. It proves that any SupCL minimizer lies in SSEM and derives explicit, practically applicable conditions on the loss weight $\alpha$ and temperature $\tau$ that prevent class collapse, linking these conditions to within-class and between-class variances. The theory also shows that SSEM achieves the maximal total variance on the unit sphere, and experiments on synthetic and real datasets validate these predictions, revealing that a moderate within-class variance yields the best transfer performance. Overall, the framework provides concrete hyperparameter guidelines and deepens the understanding of embedding geometry in SupCL, with implications for improving generalization and transfer performance.
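A convex combination of a supervised and a self-supervised contrastive loss, controlled by $\alpha$ and $\tau$, can be made concrete with a small sketch. The code below is a minimal pure-Python illustration, not the paper's implementation. It assumes three things. The self-supervised term is SimCLR-style, so each embedding's only positive is the other view of the same instance. The supervised term is SupCon-style, where positives are all other embeddings with the same label, averaged per anchor. Similarities are cosine similarities divided by $\tau$. It further assumes that $\alpha$ multiplies the supervised term. The paper's exact loss definitions and weighting convention may differ, and all function names are hypothetical.

```python
import math
import random


def _unit(v):
    """Project a vector onto the unit sphere."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def log_probs(z, tau):
    """For each anchor i, log-softmax of cos(z_i, z_j) / tau over all j != i."""
    u = [_unit(v) for v in z]
    rows = []
    for i, ui in enumerate(u):
        logits = {j: sum(a * b for a, b in zip(ui, uj)) / tau
                  for j, uj in enumerate(u) if j != i}
        m = max(logits.values())  # subtract the max for numerical stability
        log_norm = m + math.log(sum(math.exp(s - m) for s in logits.values()))
        rows.append({j: s - log_norm for j, s in logits.items()})
    return rows


def self_sup_loss(z, tau):
    """SimCLR-style term (assumed): rows i and i + n/2 are two views of one instance."""
    lp = log_probs(z, tau)
    n = len(z)
    return -sum(lp[i][(i + n // 2) % n] for i in range(n)) / n


def sup_con_loss(z, labels, tau):
    """SupCon-style term (assumed): every other row with the same label is a positive."""
    lp = log_probs(z, tau)
    total = 0.0
    for i, row in enumerate(lp):
        positives = [j for j in row if labels[j] == labels[i]]
        total += sum(row[j] for j in positives) / len(positives)
    return -total / len(lp)


def mixed_loss(z, labels, tau, alpha):
    """Convex combination; ASSUMPTION: alpha weights the supervised term."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return alpha * sup_con_loss(z, labels, tau) + (1.0 - alpha) * self_sup_loss(z, tau)


# Usage: 8 instances x 2 views in R^4; rows i and i + 8 share an instance and a label.
random.seed(0)
z = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(16)]
labels = [0, 0, 1, 1, 2, 2, 3, 3] * 2
for alpha in (0.0, 0.5, 1.0):
    print(alpha, round(mixed_loss(z, labels, tau=0.5, alpha=alpha), 4))
```

When every instance has its own label, the supervised term coincides with the self-supervised one, because each anchor's only positive is its other view. At $\alpha = 0$ and $\alpha = 1$, the mixture recovers the individual terms.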
Abstract
Supervised contrastive learning (SupCL) has emerged as a prominent approach in representation learning, leveraging both supervised and self-supervised losses. However, achieving an optimal balance between these losses is challenging; failing to do so can lead to class collapse, reducing discrimination among individual embeddings in the same class. In this paper, we present theoretically grounded guidelines for SupCL to prevent class collapse in learned representations. Specifically, we introduce the Simplex-to-Simplex Embedding Model (SSEM), a theoretical framework that models various embedding structures, including all embeddings that minimize the supervised contrastive loss. Through SSEM, we analyze how hyperparameters affect learned representations, offering practical guidelines for hyperparameter selection to mitigate the risk of class collapse. Our theoretical findings are supported by empirical results across synthetic and real-world datasets.
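Class collapse means that embeddings of the same class become indistinguishable, so the within-class variance of the normalized embeddings shrinks to zero. The following is a minimal pure-Python diagnostic. It assumes the usual population definitions, which average squared Euclidean distances to the global mean and to the class means over samples. The paper's normalization may differ, and the function name `variance_decomposition` is hypothetical.

```python
import math
from collections import defaultdict


def _unit(v):
    """Project a vector onto the unit sphere."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _mean(vectors):
    """Coordinate-wise mean of a non-empty list of equal-length vectors."""
    d = len(vectors[0])
    return [sum(v[k] for v in vectors) / len(vectors) for k in range(d)]


def _sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def variance_decomposition(z, labels):
    """Return (total, within, between) variance of L2-normalized embeddings."""
    u = [_unit(v) for v in z]
    n = len(u)
    mu = _mean(u)  # global mean embedding
    groups = defaultdict(list)
    for vec, y in zip(u, labels):
        groups[y].append(vec)
    class_means = {y: _mean(vs) for y, vs in groups.items()}
    total = sum(_sq_dist(v, mu) for v in u) / n
    within = sum(_sq_dist(v, class_means[y]) for v, y in zip(u, labels)) / n
    between = sum(len(vs) * _sq_dist(class_means[y], mu) for y, vs in groups.items()) / n
    return total, within, between


# Usage: two classes in R^2, either collapsed to one point per class or spread out.
collapsed = [[1, 0], [1, 0], [-1, 0], [-1, 0]]
spread = [[1, 0], [0, 1], [-1, 0], [0, -1]]
labels = [0, 0, 1, 1]
print(variance_decomposition(collapsed, labels))  # within-class variance is 0.0
print(variance_decomposition(spread, labels))
```

Every normalized row has unit norm, so the total variance equals $1 - \|\mu\|^2$ for the global mean $\mu$. It is therefore at most 1, and reaches 1 exactly when the mean embedding is zero. The split of the total into within-class and between-class parts is the law of total variance.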
