On Improving the Algorithm-, Model-, and Data- Efficiency of Self-Supervised Learning

Yun-Hao Cao, Jianxin Wu

TL;DR

This paper tackles the inefficiency of self-supervised learning by proposing a single-branch, non-parametric instance-discrimination framework that uses a gradient-corrected memory bank and a SqrtKL self-distillation loss to accelerate convergence. It initializes the memory bank via a forward pass on an untrained network, applies a corrected update rule that propagates updates beyond the target instance, and introduces $L_{SqrtKL}$ to balance gradient flow across classes. Empirically, the method achieves favorable accuracy with substantially lower training time and memory usage, excelling on small models and limited data, and showing strong data efficiency on ImageNet subsets and robust transfer performance. While competitive on large-scale linear evaluation, the authors emphasize efficiency and applicability in resource-constrained settings and outline future work to scale to larger datasets.
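The TL;DR describes the memory bank only in prose, so here is a minimal NumPy sketch of a non-parametric instance-discrimination memory bank. It assumes that the bank rows are L2-normalized, that `tau` is a softmax temperature, and that the bank is filled once from the embeddings of an untrained network. The update function is *not* the paper's corrected rule, which this excerpt does not state. It is a hypothetical illustration derived from the gradient of $L = -\log p_i$, and it shows why a gradient-based analysis produces updates that also reach rows other than the target instance $i$. All function names, and the `tau` and `lr` values, are assumptions.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale vectors to unit length (bank entries live on the unit hypersphere)."""
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)

def init_memory_bank(features):
    """Fill the bank from one forward pass of an *untrained* network.

    `features` is a hypothetical stand-in for the (N, D) embeddings of that pass.
    """
    return l2_normalize(np.asarray(features, dtype=np.float64))

def instance_probs(bank, f, tau=0.1):
    """Non-parametric softmax: p_j = exp(v_j . f / tau) / sum_k exp(v_k . f / tau)."""
    logits = bank @ f / tau
    logits = logits - logits.max()  # subtract the max for numerical stability
    e = np.exp(logits)
    return e / e.sum()

def gradient_style_update(bank, f, i, tau=0.1, lr=0.01):
    """Hypothetical update of EVERY bank row by one gradient step on L = -log p_i.

    dL/dv_j = (p_j - [j == i]) * f / tau, so the target row i moves toward f and
    every other row moves away from f in proportion to its probability p_j.
    """
    p = instance_probs(bank, f, tau)
    coeff = p.copy()
    coeff[i] -= 1.0  # the indicator term applies only to the target instance
    updated = bank - lr * np.outer(coeff, f) / tau
    return l2_normalize(updated)
```

An update that rewrites only row $i$ drops the $p_j$ terms for $j \neq i$. Keeping them is one plausible reading of "propagates updates beyond the target instance", under the assumptions above.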

Abstract

Self-supervised learning (SSL) has developed rapidly in recent years. However, most mainstream methods are computationally expensive and rely on two (or more) augmentations of each image to construct positive pairs. Moreover, they mainly target large models and large-scale datasets, which limits their flexibility and feasibility in many practical applications. In this paper, we propose an efficient single-branch SSL method based on non-parametric instance discrimination, aiming to improve the algorithm, model, and data efficiency of SSL. By analyzing the gradient formula, we correct the update rule of the memory bank, which improves performance. We further propose a novel self-distillation loss that minimizes the KL divergence between the probability distribution and its square-root version. We show that this alleviates the infrequent updating problem in instance discrimination and greatly accelerates convergence. We systematically compare the training overhead and performance of different methods across data scales and backbones. Experimental results show that our method outperforms various baselines with significantly less overhead and is especially effective for limited amounts of data and small models.
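The abstract describes the SqrtKL loss only in words. The sketch below assumes that the "square-root version" is the renormalized distribution $q_j = \sqrt{p_j} / \sum_k \sqrt{p_k}$, that the loss is $\mathrm{KL}(q \,\|\, p)$, and that $q$ is treated as a fixed target (stop-gradient). The KL direction and the stop-gradient are assumptions, not details taken from the paper. Under these assumptions, the gradient with respect to the logits is simply $p - q$, which generally touches every logit, not only the target's. The function names are hypothetical.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over a 1-D logit vector."""
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def sqrt_target(p):
    """Renormalized square root of p: a flatter distribution with the same ranking."""
    r = np.sqrt(p)
    return r / r.sum()

def sqrtkl_loss(logits, eps=1e-12):
    """Sketch of a SqrtKL-style loss: KL(q || p), with q = sqrt_target(p) held fixed.

    Returns (loss, gradient w.r.t. logits). With q constant,
    d/dz [-sum_j q_j log p_j] = p - q, so a descent step lowers logits whose
    probability exceeds its square-root target and raises the others.
    """
    p = softmax(np.asarray(logits, dtype=np.float64))
    q = sqrt_target(p)  # assumed stop-gradient target
    loss = float(np.sum(q * (np.log(q + eps) - np.log(p + eps))))
    grad = p - q
    return loss, grad
```

Because the square root flattens a peaked distribution, $q$ shifts mass from the dominant entry toward the tail. A descent step therefore lowers the top logit and raises the small ones.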

Paper Structure (19 sections, 21 equations, 8 figures, 9 tables)

Figures (8)

  • Figure 1: Linear probing accuracy and training cost (in hours) of different SSL methods on CIFAR-100.
  • Figure 2: The general framework of our method.
  • Figure 3: Our method with vs. without SqrtKL on CIFAR-10.
  • Figure 4: Effect of $m$ on CIFAR-10 under ResNet-18.
  • Figure 5: Effect of $\lambda$ on CIFAR-10 under ResNet-18.
  • ...and 3 more figures