Context-Enriched Contrastive Loss: Enhancing Presentation of Inherent Sample Connections in Contrastive Learning Framework

Haojin Deng, Yimin Yang

TL;DR

ConTeX addresses two core issues in contrastive learning, distortion from augmentations and slow convergence, by introducing a context-enriched loss with two complementary components: one leverages context-based contrasts and the other leverages self-positives. The method yields faster convergence and improved generalization, with strong fairness gains on bias-related tasks such as BiasedMNIST, UTKFace, and CelebA. Extensive experiments across CIFAR-10/100, ImageNet, and transfer settings show competitive or superior performance versus state-of-the-art contrastive losses, including SupCon, with substantial bias mitigation. The work demonstrates ConTeX's potential for efficient, fair downstream training and provides theoretical insights via an upper-bound analysis and gradient-level justification.

Abstract

Contrastive learning has gained popularity and has pushed state-of-the-art performance across numerous large-scale benchmarks. In contrastive learning, the contrastive loss function plays a pivotal role in discerning similarities between samples produced by augmentation techniques such as rotation or cropping. However, this learning mechanism can also introduce information distortion from the augmented samples, because the trained model may come to rely too heavily on samples with identical labels while neglecting positive pairs that originate from the same initial image, especially in large datasets. This paper proposes a context-enriched contrastive loss function that simultaneously improves learning effectiveness and addresses the information distortion by combining two convergence targets. The first component, which is notably sensitive to label contrast, differentiates between features of identical and distinct classes, which boosts contrastive training efficiency. The second component pulls augmented samples from the same source image closer together and pushes all other samples away. We evaluate the proposed approach on image classification tasks using eight widely accepted large-scale recognition benchmark datasets: CIFAR10, CIFAR100, Caltech-101, Caltech-256, ImageNet, BiasedMNIST, UTKFace, and CelebA. The experimental results show that the proposed method improves on 16 state-of-the-art contrastive learning methods in both generalization performance and convergence speed. Notably, our technique stands out on tasks involving systematic distortion: it achieves a 22.9% improvement over the original contrastive loss functions on the downstream BiasedMNIST dataset, highlighting its promise for more efficient and equitable downstream training.
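
The two convergence targets described in the abstract can be illustrated with a small sketch. This is not the paper's loss: the exact equations are not reproduced on this page, so the code below swaps in two standard stand-ins, a SupCon-style term for the label-sensitive component and an NT-Xent-style term for the self-positive component. The function name `two_part_contrastive_loss`, the `weight` parameter, and the simple weighted sum are all assumptions made for illustration.

```python
import math

def _normalize(v):
    # Assumes a non-zero vector; projects it onto the unit sphere.
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def _log_probs(i, sims, tau):
    """Log-softmax over every view j != i, using similarities scaled by 1/tau."""
    logits = {j: s / tau for j, s in enumerate(sims[i]) if j != i}
    top = max(logits.values())
    log_z = top + math.log(sum(math.exp(l - top) for l in logits.values()))
    return {j: l - log_z for j, l in logits.items()}

def two_part_contrastive_loss(embeddings, labels, source_ids, tau=0.1, weight=1.0):
    """Hypothetical two-component loss (illustrative stand-in, not the paper's).

    embeddings: one feature vector per augmented view
    labels:     class label of each view
    source_ids: id of the original image each view was augmented from
    Returns (label_term, self_term, total), each averaged over anchors.
    """
    z = [_normalize(v) for v in embeddings]
    n = len(z)
    sims = [[_dot(z[i], z[j]) for j in range(n)] for i in range(n)]
    label_term = 0.0
    self_term = 0.0
    for i in range(n):
        log_p = _log_probs(i, sims, tau)
        # Component 1 (SupCon-style): positives share the anchor's label.
        same_label = [j for j in log_p if labels[j] == labels[i]]
        # Component 2 (NT-Xent-style): positives share the anchor's source image.
        same_source = [j for j in log_p if source_ids[j] == source_ids[i]]
        if same_label:
            label_term -= sum(log_p[j] for j in same_label) / len(same_label)
        if same_source:
            self_term -= sum(log_p[j] for j in same_source) / len(same_source)
    label_term /= n
    self_term /= n
    return label_term, self_term, label_term + weight * self_term

# Usage: four views from two source images, one image per class.
views = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
label_term, self_term, total = two_part_contrastive_loss(
    views, labels=[0, 0, 1, 1], source_ids=[0, 0, 1, 1]
)
```

In this sketch the first term pulls together all views that share a label, while the second term treats only the view from the same source image as positive, so a sample with the same label but a different source still counts as a negative there. How the paper actually balances and formulates the two parts may differ.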

Paper Structure

This paper contains 33 sections, 1 theorem, 30 equations, 7 figures, 9 tables.

Key Result

Lemma 1

At least two positive pairs corresponding to one representation exist in a mini-batch if the batch size $n$ is larger than the number of label classes $m$.
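
One way to read the lemma is as a pigeonhole argument: with more source images than classes, at least two images must share a label, so some anchor has both its self positive and a same-label positive. The sketch below checks this reading by brute force for small batches. It assumes each of the $n$ source images contributes two augmented views, which is an assumption of this illustration rather than something stated on this page; the function name `guaranteed_positives` is hypothetical, and the paper's own proof may proceed differently.

```python
from itertools import product

def guaranteed_positives(n, m, views=2):
    """Smallest, over all labelings of n source images into m classes, of the
    largest number of positives that any single anchor view has in the batch."""
    worst_case = None
    for labeling in product(range(m), repeat=n):
        # An anchor's positives are all other views that share its label.
        best_anchor = max(views * labeling.count(c) - 1 for c in set(labeling))
        worst_case = best_anchor if worst_case is None else min(worst_case, best_anchor)
    return worst_case

print(guaranteed_positives(3, 2))  # n > m: prints 3, so some anchor has at least two positives
print(guaranteed_positives(3, 3))  # n = m: prints 1, only the self positive is guaranteed
```

Under this reading, the guarantee disappears as soon as $n \le m$, because every image can then carry a distinct label.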

Figures (7)

  • Figure 1: Randomly augmented images from the ImageNet dataset. All of these images are from the 'n04536866' category. The images in the first row are the original images. The images in the same column are from the same source image.
  • Figure 2: Our proposed loss function vs. previous contrastive loss functions in latent space. The blue dot is the current anchor. The red dots are context negatives (samples whose label or context differs from the anchor's). The yellow dots are context positives (samples with the same label as, or a similar context to, the anchor). The light cyan dot is the self positive (the augmented sample from the same original sample as the anchor). The first part of our loss function is similar to (b), but we enhance the contrast between the context positives and context negatives (Eq. \ref{loss part 1}). The second part of our loss function (Eq. \ref{loss part 2}) follows a similar idea to (a). Our combined loss function creates two stable boundaries so that similarity is maximized only for the self-positive sample, which avoids being misled by labels.
  • Figure 3: The framework of our stage-one training. The images with a light blue frame are the anchor and its self positive (the augmented sample from the same original sample as the anchor). The images with a red frame are context negatives (samples whose labels differ from the anchor's). The images with a yellow frame are context positives (samples with the same label as the anchor). The images with the light cyan frame and those with the yellow frame all belong to the 'Violin' class, and the images with the red frame are from the 'Ocarina' class. In the linear evaluation and fine-tuning stage, we use only the ResNet-50 encoder from stage one and add a linear fully connected layer as a classifier. The output dimension of ResNet-50 is 2048, and each layer of the projection head block contains 128 neurons.
  • Figure 4: Comparison of unbiased accuracy on the BiasedMNIST dataset with high target-bias correlations.
  • Figure 5: Fine-tuning evaluation accuracy of transfer learning between CIFAR10 and CIFAR100. The blue line represents the ConTeX loss function and the orange line represents the SupCon loss function. We first pretrain ResNet-50 with each loss function on each dataset (CIFAR10 and CIFAR100) for 1000 epochs. We then fine-tune the network on the other dataset for 100 epochs.
  • ...and 2 more figures

Theorems & Definitions (2)

  • Lemma 1
  • proof