Variational Supervised Contrastive Learning
Ziwen Wang, Jiajun Fan, Thao Nguyen, Heng Ji, Ge Liu
TL;DR
This work reframes supervised contrastive learning as variational inference over latent class variables, yielding a posterior-weighted ELBO that uses class centroids and a confidence-aware temperature to regulate intra-class dispersion. The VarCon objective couples distribution alignment (via KL divergence) with a classification likelihood term, so that training cost scales near-linearly with batch size and the embeddings acquire clearer semantic structure. Empirical results on CIFAR and ImageNet benchmarks show state-of-the-art Top-1 accuracy among contrastive learning frameworks, faster convergence, stronger few-shot and transfer performance, and robustness to augmentation strategies and corruption. The approach bridges discriminative and generative perspectives by giving contrastive learning explicit probabilistic semantics and uncertainty modeling.
Abstract
Contrastive learning has proven to be highly efficient and adaptable in shaping representation spaces across diverse modalities by pulling similar samples together and pushing dissimilar ones apart. However, two key limitations persist: (1) without explicit regulation of the embedding distribution, semantically related instances can inadvertently be pushed apart unless complementary signals guide pair selection, and (2) excessive reliance on large in-batch negatives and tailored augmentations hinders generalization. To address these limitations, we propose Variational Supervised Contrastive Learning (VarCon), which reformulates supervised contrastive learning as variational inference over latent class variables and maximizes a posterior-weighted evidence lower bound (ELBO). This objective replaces exhaustive pair-wise comparisons with efficient class-aware matching and grants fine-grained control over intra-class dispersion in the embedding space. Our experiments on CIFAR-10, CIFAR-100, ImageNet-100, and ImageNet-1K, with models trained exclusively on image data, show that VarCon (1) achieves state-of-the-art performance for contrastive learning frameworks, reaching 79.36% Top-1 accuracy on ImageNet-1K and 78.29% on CIFAR-100 with a ResNet-50 encoder while converging in just 200 epochs; (2) yields substantially clearer decision boundaries and semantic organization in the embedding space, as evidenced by KNN classification, hierarchical clustering results, and transfer-learning assessments; and (3) outperforms the supervised baseline in few-shot learning and is more robust across various augmentation strategies. Our code is available at https://github.com/ziwenwang28/VarContrast.
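
To make the ingredients named above concrete (class centroids, a confidence-aware temperature, a KL alignment term, and a classification likelihood term), the following is a minimal pure-Python sketch. It is an illustration under stated assumptions, not the released VarCon implementation: the abstract does not give the objective in closed form, so the function names, the linear temperature rule, the softened one-hot target, and the default values of `tau` and `eps` are all hypothetical choices made for this example.

```python
import math

def normalize(v):
    """L2-normalize a vector; an all-zero vector is returned unchanged."""
    n = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / n for x in v]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def class_centroids(z, labels, num_classes):
    """Per-class sum of normalized embeddings, re-normalized to unit length.
    A class absent from the batch keeps a zero centroid (logit 0)."""
    dim = len(z[0])
    sums = [[0.0] * dim for _ in range(num_classes)]
    for zi, yi in zip(z, labels):
        for d in range(dim):
            sums[yi][d] += zi[d]
    return [normalize(s) for s in sums]

def varcon_style_loss(embeddings, labels, num_classes, tau=0.1, eps=0.02):
    """Illustrative centroid-based loss: KL(q || p) - log p(y | z), batch-averaged.

    ASSUMPTIONS (not taken from the abstract):
      * p(k | z) is a softmax over cosine similarity to class centroids / tau.
      * q(k | z) is a softened one-hot target whose temperature tau_q grows
        linearly with the model's confidence in the true class.
    """
    z = [normalize(e) for e in embeddings]
    centroids = class_centroids(z, labels, num_classes)
    total = 0.0
    for zi, yi in zip(z, labels):
        # One comparison per class, not per pair: cost is O(batch * classes).
        logits = [sum(a * b for a, b in zip(zi, c)) / tau for c in centroids]
        p = softmax(logits)
        confidence = p[yi]
        # Hypothetical confidence-aware temperature in [tau - eps, tau + eps].
        tau_q = (tau - eps) + 2.0 * eps * confidence
        q = softmax([(1.0 if k == yi else 0.0) / tau_q for k in range(num_classes)])
        kl = sum(qk * math.log(qk / pk) for qk, pk in zip(q, p))
        nll = -math.log(p[yi])
        total += kl + nll
    return total / len(z)

if __name__ == "__main__":
    batch = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
    print("consistent labels:", varcon_style_loss(batch, [0, 0, 1, 1], 2))
    print("shuffled labels:  ", varcon_style_loss(batch, [0, 1, 0, 1], 2))
```

In this sketch each sample is compared with one centroid per class rather than with every other sample in the batch, which is one way to read the claim that the objective replaces exhaustive pair-wise comparisons. Labels that agree with the geometry of the embeddings should give a much smaller loss than shuffled labels.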
