A Theoretical Framework for Preventing Class Collapse in Supervised Contrastive Learning

Chungpa Lee, Jeongheon Oh, Kibok Lee, Jy-yong Sohn

TL;DR

This work tackles class collapse in supervised contrastive learning (SupCL) by introducing the Simplex-to-Simplex Embedding Model (SSEM), a geometric framework that characterizes all optimal embeddings under a convex combination of supervised and self-supervised losses. It proves that any SupCL minimizer lies in SSEM and derives explicit, practically applicable conditions on the loss weight $\alpha$ and temperature $\tau$ to prevent class collapse, linking these to within-class and between-class variances. Theoretical results show SSEM achieves maximal total variance on the unit sphere, and experiments on synthetic and real datasets validate these predictions, revealing that a moderate within-class variance yields the best transfer performance. Overall, the framework provides concrete hyperparameter guidelines and deepens understanding of embedding geometry in SupCL, with clear implications for improving generalization and transfer tasks.

Abstract

Supervised contrastive learning (SupCL) has emerged as a prominent approach in representation learning, leveraging both supervised and self-supervised losses. However, achieving an optimal balance between these losses is challenging; failing to do so can lead to class collapse, reducing discrimination among individual embeddings in the same class. In this paper, we present theoretically grounded guidelines for SupCL to prevent class collapse in learned representations. Specifically, we introduce the Simplex-to-Simplex Embedding Model (SSEM), a theoretical framework that models various embedding structures, including all embeddings that minimize the supervised contrastive loss. Through SSEM, we analyze how hyperparameters affect learned representations, offering practical guidelines for hyperparameter selection to mitigate the risk of class collapse. Our theoretical findings are supported by empirical results across synthetic and real-world datasets.
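The balance between the two losses described above can be sketched in NumPy. This is only an illustrative sketch under stated assumptions: the weighting convention `(1 - alpha) * supervised + alpha * self_supervised`, the two-view setup (`u`, `v`), and all function names are hypothetical and may differ from the exact loss definitions used in the paper.

```python
import numpy as np

def _log_softmax_rows(logits):
    # Numerically stable row-wise log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def self_supervised_loss(u, v, tau):
    # InfoNCE-style loss: the only positive for u[i] is its paired view v[i].
    log_p = _log_softmax_rows(u @ v.T / tau)
    return -np.mean(np.diag(log_p))

def supervised_loss(u, v, labels, tau):
    # Supervised contrastive-style loss: every v[j] sharing the label of u[i]
    # counts as a positive, averaged over the positives of each anchor.
    log_p = _log_softmax_rows(u @ v.T / tau)
    same = (labels[:, None] == labels[None, :]).astype(float)
    return -np.mean((same * log_p).sum(axis=1) / same.sum(axis=1))

def combined_loss(u, v, labels, alpha, tau):
    # ASSUMED convention: alpha weights the self-supervised term.
    return ((1.0 - alpha) * supervised_loss(u, v, labels, tau)
            + alpha * self_supervised_loss(u, v, tau))
```

Under this assumed convention, `alpha = 1` reduces to the purely self-supervised loss and `alpha = 0` to the purely supervised one; the rows of `u` and `v` are expected to be unit-norm embeddings of two views of the same instances.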

Paper Structure

This paper contains 37 sections, 17 theorems, 126 equations, 4 figures, 6 tables.

Key Result

Proposition 4.1

Suppose $mn\geq 2$ and $d\geq mn-1$ hold. Let a set of $mn$ vectors $\{{\bm{w}}_{i,j}\}_{i\in[m], j\in[n]}$ form the $(mn-1)$-simplex ETF in $\mathbb{R}^{d}$. For a given $\delta \in \left[0,\sqrt{\frac{mn-1}{m(n-1)}}\right]$, define the set of $mn$ vectors ${\bm{U}}^\delta := \{{\bm{u}}_{i,j}^{\delta}\}_{i\in[m], j\in[n]}$, where [...]. Then, the set of $mn$ vectors ${\bm{U}}^\delta$ forms an $(m,n,\delta)$-SSEM.
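As a concrete illustration of the simplex ETF that the proposition starts from, the sketch below builds $k$ vectors using the standard characterization of a $(k-1)$-simplex ETF: unit norm and pairwise inner product $-1/(k-1)$. That this matches Definition 4.1 exactly is an assumption, and the function name `simplex_etf` is hypothetical. The vectors are embedded in $\mathbb{R}^{k}$, which satisfies $d \geq k-1$.

```python
import numpy as np

def simplex_etf(k):
    """Return a (k, k) array whose rows form a (k-1)-simplex ETF in R^k."""
    if k < 2:
        raise ValueError("a simplex ETF needs at least two vectors")
    # Center the standard basis, then rescale each row to unit norm.
    centered = np.eye(k) - np.full((k, k), 1.0 / k)
    return np.sqrt(k / (k - 1)) * centered

# Example with m = 2 classes and n = 2 instances per class, so k = mn = 4.
w = simplex_etf(4)
gram = w @ w.T  # diagonal entries are 1, off-diagonal entries are -1/3
```

Centering the standard basis is only one convenient choice; any rotation of these rows is also a simplex ETF.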

Figures (4)

  • Figure 1: Illustration of the proposed Simplex-to-Simplex Embedding Model (SSEM) in Def. 4.2, where both the number of classes ($m$) and the number of instances per class ($n$) are set to 2. The set of embedding vectors in SSEM is denoted by ${\bm{U}} = \{ {\bm{u}}_{1,1}, {\bm{u}}_{1,2}, {\bm{u}}_{2,1}, {\bm{u}}_{2,2} \}$, where the superscript $\delta$ of ${\bm{U}}^\delta$ is omitted for simplicity. Each embedding's first subscript indicates its class, with embeddings of class 1 drawn in red and those of class 2 in blue. The embeddings are visualized for different values of $\delta$: (a) When $\delta=0$, SSEM equals the $1$-simplex ETF, which corresponds to class collapse. (b) When $\delta=1$, SSEM equals the $3$-simplex ETF, in which all embeddings are equidistant. (c) As $\delta$ varies over $[0, \sqrt{1.5}]$, the upper arc shows the trajectories of ${\bm{u}}_{1,1}$ and ${\bm{u}}_{1,2}$ and the lower arc shows the trajectories of ${\bm{u}}_{2,1}$ and ${\bm{u}}_{2,2}$, with the trajectory color transitioning from purple to yellow as $\delta$ increases.
  • Figure 2: The within-class variance (averaged over classes) of the learned embedding set, $\frac{1}{m} \sum_{i\in [m]} \mathrm{Var} [{\bm{U}}_i]$, for various values of the loss-combining coefficient $\alpha$ and temperature $\tau$. (Top): Derived from the theoretical results on embedding variance. (Bottom): Computed from the experiments on synthetic datasets. The two panels are well aligned. The red dashed line in the top panel marks the boundary of the region with zero within-class variance, i.e., where class collapse happens.
  • Figure 3: Average within-class variance of the learned embeddings obtained in theory (lines) and by experiments (dots), measured on the CIFAR-10 dataset with a ResNet-18 encoder. (a), (b): Dependency of the within-class variance on $\alpha$ and $\tau$, for per-class batch sizes $\tilde{n}=10, 50, 200$. The values obtained in experiments match those computed from our theoretical results. (c): Relationship between the within-class variance of the learned embeddings and the transfer learning performance (when transferred to CIFAR-100) of the CIFAR-10 trained ResNet-18 encoder. We set $\tau=0.1$ and $\tilde{n}=200$, and run experiments for $\alpha=0.0, 0.1, \ldots, 1.0$. Note that embeddings with a moderate amount of within-class variance achieve the highest performance.
  • Figure B.1: The within-class variance (averaged over classes) of the learned embedding set ${\bm{U}}$, for various values of the loss-combining coefficient $\alpha$ and temperature $\tau$. (Top): Computed from the theoretical results on embedding variance. (Bottom): Computed from the additional experiments on synthetic datasets reported in the appendix.
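The behavior described in the Figure 1 caption (class collapse at $\delta=0$, the full simplex ETF at $\delta=1$, and an upper limit of $\sqrt{1.5}$ for $m=n=2$) can be reproduced with the following sketch. The construction is a hypothetical one chosen to match those stated properties: it scales each simplex-ETF vector's offset from its class mean by $\delta$ and rescales the class mean to keep unit norm. It is not claimed to be the paper's defining formula, and the variance used here (mean squared distance to the class mean) is an assumed stand-in for Definition 5.1. It requires $m \geq 2$.

```python
import numpy as np

def simplex_etf(k):
    # Rows form a (k-1)-simplex ETF: unit norm, pairwise inner product -1/(k-1).
    return np.sqrt(k / (k - 1)) * (np.eye(k) - np.full((k, k), 1.0 / k))

def ssem_like(m, n, delta):
    """Hypothetical SSEM-style embeddings, shape (m, n, m*n); requires m >= 2."""
    k = m * n
    w = simplex_etf(k).reshape(m, n, k)      # n consecutive vectors per class
    centers = w.mean(axis=1, keepdims=True)  # class means of the ETF vectors
    within = m * (n - 1) / (k - 1)           # squared norm of w - class mean
    between = (m - 1) / (k - 1)              # squared norm of a class mean
    a_sq = (1.0 - delta**2 * within) / between
    if a_sq < -1e-12:
        raise ValueError("delta exceeds sqrt((mn-1)/(m(n-1)))")
    a = np.sqrt(max(a_sq, 0.0))              # keeps every embedding on the unit sphere
    return a * centers + delta * (w - centers)

def mean_within_class_variance(u):
    # ASSUMED definition: mean squared distance to the class mean, averaged over classes.
    centers = u.mean(axis=1, keepdims=True)
    return float(np.mean(np.sum((u - centers) ** 2, axis=-1)))

collapsed = ssem_like(2, 2, 0.0)  # both instances of each class coincide
full_etf = ssem_like(2, 2, 1.0)   # recovers the 3-simplex ETF
```

In this sketch the within-class variance equals $\delta^2\, m(n-1)/(mn-1)$, so it is zero exactly at $\delta=0$ and grows monotonically with $\delta$ up to the limit of the admissible range.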

Theorems & Definitions (35)

  • Definition 4.1: Simplex ETF
  • Definition 4.2: Simplex-to-Simplex Embedding Model
  • Proposition 4.1: Existence of SSEM
  • Theorem 4.1: Optimality of SSEM
  • Proposition 4.2
  • Remark 1
  • Definition 5.1: Variance of Embeddings
  • Proposition 5.1: Bounded Variance
  • Proposition 5.2: Variance of SSEM
  • Remark 2
  • ...and 25 more