EMC$^2$: Efficient MCMC Negative Sampling for Contrastive Learning with Global Convergence

Chung-Yiu Yau, Hoi-To Wai, Parameswaran Raman, Soumajyoti Sarkar, Mingyi Hong

TL;DR

EMC$^2$ addresses the cost of sampling from a large set of negatives in contrastive learning by coupling online Metropolis-Hastings sampling with state-dependent SGD. The paper proves convergence to a stationary point of the global contrastive loss at rate $\mathcal{O}(1/\sqrt{T})$ and shows that this rate holds regardless of batch size and burn-in period, while reducing memory and computation relative to prior methods. The analysis establishes geometric ergodicity of the MCMC components and Lipschitz smoothness of the state-dependent kernel, which together enable a biased stochastic approximation argument. Experiments on STL-10 and Imagenet-100 show that EMC$^2$ enables efficient small-batch pre-training, with linear-probe (LP) and 1-NN performance competitive with strong baselines.

Abstract

A key challenge in contrastive learning is to generate negative samples from a large sample set to contrast with positive samples, so as to learn a better encoding of the data. These negative samples often follow a softmax distribution that is dynamically updated during the training process. However, sampling from this distribution is non-trivial due to the high computational cost of computing the partition function. In this paper, we propose an Efficient Markov Chain Monte Carlo negative sampling method for Contrastive learning (EMC$^2$). We follow the global contrastive learning loss introduced in SogCLR, and propose EMC$^2$, which uses an adaptive Metropolis-Hastings subroutine to generate hardness-aware negative samples in an online fashion during the optimization. We prove that EMC$^2$ finds an $\mathcal{O}(1/\sqrt{T})$-stationary point of the global contrastive loss in $T$ iterations. Compared to prior works, EMC$^2$ is the first algorithm that exhibits global convergence (to stationarity) regardless of the choice of batch size while exhibiting low computation and memory cost. Numerical experiments validate that EMC$^2$ is effective with small batch training and achieves comparable or better performance than baseline algorithms. We report the results for pre-training image encoders on STL-10 and Imagenet-100.
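
The abstract's point that Metropolis-Hastings avoids the partition function can be illustrated with a minimal sketch. It assumes a fixed vector of similarity scores, a symmetric uniform proposal over candidate negatives, and a target proportional to exp(score / temperature). The names `mh_step`, `sample_negatives`, `scores`, and `temperature` are hypothetical and not taken from the paper, and the sketch omits the online, state-dependent coupling with SGD that EMC$^2$ relies on.

```python
import numpy as np

def mh_step(state, scores, temperature, rng):
    """One Metropolis-Hastings step targeting p(j) proportional to exp(scores[j] / temperature)."""
    proposal = rng.integers(len(scores))  # symmetric uniform proposal (assumption)
    # The acceptance ratio needs only a score difference: the partition function cancels.
    log_ratio = (scores[proposal] - scores[state]) / temperature
    if rng.random() < np.exp(min(0.0, log_ratio)):
        return proposal  # accept the proposed negative
    return state         # reject and keep the current negative

def sample_negatives(scores, temperature, num_steps, seed=0):
    """Run one chain and return the empirical visit frequency of each candidate."""
    rng = np.random.default_rng(seed)
    state = rng.integers(len(scores))
    counts = np.zeros(len(scores))
    for _ in range(num_steps):
        state = mh_step(state, scores, temperature, rng)
        counts[state] += 1
    return counts / num_steps
```

Over a long run the visit frequencies should approach the softmax of `scores / temperature`, even though the normalising sum is never computed inside the sampler.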

Paper Structure (26 sections, 4 theorems, 38 equations, 8 figures, 3 tables, 1 algorithm)

Key Result

Lemma 3.2

Under the bounded-embedding assumption (assm:bounded_embd), for any $\theta \in \mathbb{R}^p$ and any initialization $\tilde{\bm{\xi}}_0 \in \Xi$, the Markov chain $\tilde{\bm{\xi}}_0 \to \tilde{\bm{\xi}}_1 \to \cdots$ induced by the transition kernel ${\tt P}_{\theta}$ converges geometrically to the stationary distribution for any ${\bm z}$ and any $\tau \geq 0$.
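
Geometric convergence of this kind can be checked numerically on a toy chain. The sketch below is not the paper's kernel ${\tt P}_{\theta}$: it assumes a small finite set of candidates, a uniform-proposal Metropolis-Hastings kernel, and a softmax target, and it compares the total variation distance after each step against the standard Doeblin bound $(1-\delta)^{\tau}$ with $\delta = 1/(n \max_j \pi_j)$. The function names `mh_kernel` and `tv_distances` are hypothetical.

```python
import numpy as np

def mh_kernel(scores, temperature):
    """Transition matrix of uniform-proposal M-H targeting softmax(scores / temperature)."""
    n = len(scores)
    logits = np.asarray(scores, dtype=float) / temperature
    pi = np.exp(logits - logits.max())
    pi /= pi.sum()  # normalised here only so that convergence can be measured
    accept = np.minimum(1.0, pi[None, :] / pi[:, None])  # accept[i, j] = min(1, pi_j / pi_i)
    P = accept / n                            # propose j with probability 1/n, then accept
    np.fill_diagonal(P, 0.0)
    np.fill_diagonal(P, 1.0 - P.sum(axis=1))  # remaining mass: the chain stays where it is
    return P, pi

def tv_distances(P, pi, start, num_steps):
    """Total variation distance to pi after 1, ..., num_steps steps from a fixed start state."""
    mu = np.zeros(len(pi))
    mu[start] = 1.0
    out = []
    for _ in range(num_steps):
        mu = mu @ P
        out.append(0.5 * np.abs(mu - pi).sum())
    return np.array(out)

P, pi = mh_kernel([0.0, 0.5, 1.0, 2.0], temperature=0.5)
tv = tv_distances(P, pi, start=0, num_steps=30)
rho = 1.0 - 1.0 / (len(pi) * pi.max())  # contraction factor from the Doeblin condition
```

In this toy setting bounded scores keep $\max_j \pi_j$ below 1, so $\delta > 0$ and the distance shrinks at least as fast as `rho ** tau`. This mirrors the role that bounded embeddings play in the lemma, although the paper's constants and kernel differ.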

Figures (8)

  • Figure 1: Training 100 epochs on STL-10 with ResNet-18 using batch size $b = 32$. Horizontal axis is relative to the wall-clock training time in seconds.
  • Figure 2: Illustration of mini-batch MCMC sampling with 2 augmentations $(x_i', x_i'')$ of each image $x_i$. Each horizontal arrow represents a distribution tracked by one Markov chain and the direction of the M-H accept/reject step. The shaded area represents the samples used for burn-in with burn-in period $P<2b-2$. Crossed-out diagonals are not regarded as negative samples.
  • Figure 3: Comparison between different sizes of pre-augmented STL-10 with ResNet-18 and batch size $b = 256$. Horizontal axis is relative to the number of samples accessed.
  • Figure 4: Comparison between different numbers of burn-in negative samples $P$ for each Markov chain state $Z_i$.
  • Figure 5: Comparison on STL-10 with ResNet-18 using batch size (top) $b = 32$, (bottom) $b = 256$. Horizontal axis is relative to the number of samples accessed.
  • ...and 3 more figures

Theorems & Definitions (4)

  • Lemma 3.2
  • Lemma 3.4
  • Theorem 3.7
  • Corollary 3.8