ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring

David Berthelot, Nicholas Carlini, Ekin D. Cubuk, Alex Kurakin, Kihyuk Sohn, Han Zhang, Colin Raffel

TL;DR

ReMixMatch extends MixMatch for semi-supervised learning (SSL) with two techniques, distribution alignment and augmentation anchoring, together with a CTAugment-based strong-augmentation strategy. Distribution alignment pushes the marginal distribution of predictions on unlabeled data toward the marginal distribution of ground-truth labels. Augmentation anchoring uses the prediction for a weakly augmented input as the target for several strongly augmented versions of the same input, improving stability and performance. Empirically, ReMixMatch achieves state-of-the-art results on CIFAR-10, SVHN, and STL-10 with far fewer labeled examples: with 250 labels on CIFAR-10 it reaches 93.73% accuracy, compared with MixMatch's 93.58% using 4,000 labels. It also performs well in the few-shot regime, with a median accuracy of 84.92% at four labels per class, and ablations validate each component. The code is open-source.

Abstract

We improve the recently-proposed "MixMatch" semi-supervised learning algorithm by introducing two new techniques: distribution alignment and augmentation anchoring. Distribution alignment encourages the marginal distribution of predictions on unlabeled data to be close to the marginal distribution of ground-truth labels. Augmentation anchoring feeds multiple strongly augmented versions of an input into the model and encourages each output to be close to the prediction for a weakly-augmented version of the same input. To produce strong augmentations, we propose a variant of AutoAugment which learns the augmentation policy while the model is being trained. Our new algorithm, dubbed ReMixMatch, is significantly more data-efficient than prior work, requiring between $5\times$ and $16\times$ less data to reach the same accuracy. For example, on CIFAR-10 with 250 labeled examples we reach $93.73\%$ accuracy (compared to MixMatch's accuracy of $93.58\%$ with $4{,}000$ examples) and a median accuracy of $84.92\%$ with just four labels per class. We make our code and data open-source at https://github.com/google-research/remixmatch.
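The augmentation anchoring objective described in the abstract can be illustrated with a small sketch. This is a simplified illustration under stated assumptions, not the authors' implementation: the function names (`softmax`, `cross_entropy`, `anchoring_loss`) are hypothetical, cross-entropy is assumed as the measure of closeness, and the sketch omits the other steps of the full method (distribution alignment, mixing, and the learned augmentation policy).

```python
import math

def softmax(logits):
    """Convert raw logits to a probability distribution (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, probs, eps=1e-12):
    """Cross-entropy H(target, probs); eps guards against log(0)."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, probs))

def anchoring_loss(weak_logits, strong_logits_list):
    """Average cross-entropy between the weak-augmentation prediction
    (the anchor) and the prediction for each strong augmentation.

    Assumption: in a real training framework the anchor would be treated
    as a constant, so no gradient flows through it.
    """
    anchor = softmax(weak_logits)
    losses = [cross_entropy(anchor, softmax(s)) for s in strong_logits_list]
    return sum(losses) / len(losses)

# Illustrative logits for one image: one weak view and two strong views.
weak = [2.0, 0.0, 0.0]
strong_views = [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0]]
loss = anchoring_loss(weak, strong_views)
```

A strong view whose prediction agrees with the anchor contributes a small term to the average, while a view that disagrees contributes a large one, so minimizing the loss pulls every strong-augmentation output toward the weak-augmentation prediction.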

Paper Structure

This paper contains 32 sections, 3 equations, 3 figures, 6 tables.

Figures (3)

  • Figure 1: Distribution alignment. Guessed label distributions are rescaled by the ratio of the empirical ground-truth class distribution to the average model prediction on unlabeled data.
  • Figure 2: Augmentation anchoring. We use the prediction for a weakly augmented image (green, middle) as the target for predictions on strong augmentations of the same image (blue).
  • Figure 3: KL divergence between the marginal distribution of model predictions and the true marginal distribution of class labels over the course of training, with and without distribution alignment. This figure corresponds to a training run on CIFAR-10 with 250 labels.
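
The adjustment in Figure 1 and the quantity plotted in Figure 3 can be sketched in a few lines. This is a minimal sketch under stated assumptions rather than the released code: the names `align_guess` and `kl_divergence` are hypothetical, the final renormalization is assumed so that the result is again a valid distribution, the direction of the KL divergence is an assumption (the caption does not state it), and how the average model prediction is maintained during training is left to the caller.

```python
import math

def align_guess(guess, class_prior, mean_pred, eps=1e-12):
    """Rescale a guessed label distribution by class_prior / mean_pred,
    then renormalize so the result sums to 1.

    guess:       model's guessed label distribution for one unlabeled input
    class_prior: empirical ground-truth class distribution
    mean_pred:   average model prediction on unlabeled data
    """
    scaled = [g * p / (m + eps) for g, p, m in zip(guess, class_prior, mean_pred)]
    total = sum(scaled)
    return [s / total for s in scaled]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions; terms with p_i = 0 are skipped."""
    return sum(pi * math.log((pi + eps) / (qi + eps))
               for pi, qi in zip(p, q) if pi > 0)

# Illustrative values: three classes, uniform labels, and a model that
# over-predicts class 0 and under-predicts class 2 on average.
prior = [1 / 3, 1 / 3, 1 / 3]
mean_pred = [0.6, 0.3, 0.1]
guess = [0.7, 0.2, 0.1]

aligned = align_guess(guess, prior, mean_pred)
gap = kl_divergence(prior, mean_pred)  # assumed direction: KL(true || predicted)
```

In this example the probability of the over-predicted class shrinks and that of the under-predicted class grows. When the average prediction already equals the class prior, the guess is returned essentially unchanged.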