NorMatch: Matching Normalizing Flows with Discriminative Classifiers for Semi-Supervised Learning

Zhongying Deng, Rihuan Ke, Carola-Bibiane Schonlieb, Angelica I Aviles-Rivero

TL;DR

Semi-supervised learning with few labels suffers from noisy pseudo-labels and confirmation bias. The authors propose NorMatch, which pairs a discriminative classifier with a Normalizing Flow Classifier (NFC). The NFC serves two purposes: estimating pseudo-label uncertainty through the consensus of the two classifiers (NCUE) and modeling the unlabeled data distribution (NUM). Everything is trained with a joint objective that includes a threshold-free sample weighting strategy. The training losses combine supervised terms on labeled data with weighted unsupervised terms on unlabeled data, with gradients reaching the NFC but not the backbone. Empirical results on CIFAR-10/100, STL-10, and Mini-ImageNet show NorMatch achieving state-of-the-art or competitive performance across several label regimes. The authors acknowledge the additional computational cost and the method's sensitivity to the underlying SSL baseline.

Abstract

Semi-Supervised Learning (SSL) aims to learn a model using a tiny labeled set and massive amounts of unlabeled data. To better exploit the unlabeled data, the latest SSL methods use pseudo-labels predicted by a single discriminative classifier. However, the generated pseudo-labels are inevitably linked to inherent confirmation bias and noise, which greatly affect model performance. In this work, we introduce a new framework for SSL named NorMatch. Firstly, we introduce a new uncertainty estimation scheme based on normalizing flows, used as an auxiliary classifier, to enforce highly certain pseudo-labels, yielding a boost to the discriminative classifier. Secondly, we introduce a threshold-free sample weighting strategy to better exploit both high- and low-confidence pseudo-labels. Furthermore, we utilize normalizing flows to model, in an unsupervised fashion, the distribution of unlabeled data. This modelling assumption can further improve the performance of generative classifiers via unlabeled data and thus implicitly contributes to training a better discriminative classifier. We demonstrate, through numerical and visual results, that NorMatch achieves state-of-the-art performance on several datasets.
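
To make the role of the flow-based classifier more concrete, the following is a sketch that assumes a FlowGMM-style construction, as in the NFC of Izmailov et al. (2020) cited in Figure 1. The symbols $f_\theta$ (the invertible flow), $\mu_k, \Sigma_k$ (per-class Gaussian parameters in latent space), and $\pi_k$ (class priors) are notation introduced here for illustration and may differ from the paper's. Under that assumption, the class-conditional density and the class posterior are

$$
p(x \mid y = k) = \mathcal{N}\big(f_\theta(x);\, \mu_k, \Sigma_k\big)\,\left|\det \frac{\partial f_\theta(x)}{\partial x}\right|,
\qquad
p(y = k \mid x) = \frac{\pi_k\, p(x \mid y = k)}{\sum_{j} \pi_j\, p(x \mid y = j)},
$$

and unlabeled data can be modeled by minimizing the negative marginal log-likelihood

$$
\mathcal{L}_{\text{unsup}}(x) = -\log \sum_{k} \pi_k\, p(x \mid y = k).
$$

Here $x$ denotes the input to the flow, which Figure 2 suggests is the backbone feature of the weakly-augmented sample. The Jacobian term cancels in the posterior, so it affects the likelihood objective but not the predicted class.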

Paper Structure (16 sections, 10 equations, 6 figures, 6 tables)

Figures (6)

  • Figure 1: Comparison of a) a discriminative classifier, b) a normalizing flow classifier (NFC) (Izmailov et al., 2020), and c) discriminative + normalizing flow classifiers in predicting unlabeled data. The data points in all sub-plots come from the same set of inputs. Our goal is to predict highly certain pseudo-labels for the unlabeled samples (the red dots). Inconsistent predictions from the discriminative classifier and the NFC on a sample (e.g., the left red dot) indicate that the pseudo-label is less trustworthy (e.g., the uncertain region in c)), so we downplay its importance to avoid over-confidence. In contrast, if both classifiers make consistent predictions (e.g., on the right red dot), the pseudo-label has higher certainty. Ideally, if the predictions are consistent under any hypotheses, i.e., across any different classifiers, we can fully trust the predicted pseudo-labels.
  • Figure 2: Overview of NorMatch for unlabeled data. Modules with the same colors (i.e., the CNN and the discriminative classifier) share the same set of parameters. NorMatch uses the shared CNN backbone to extract the features of weakly- and strongly-augmented versions of the same unlabeled sample. The weakly-augmented features are then input to the Normalizing Flow Classifier (NFC) for Unsupervised Modeling (NUM) by likelihood maximization, and to the discriminative classifier to obtain the pseudo-labels. These features are also input to both the NFC and the discriminative classifier to enforce a Consensus Uncertainty Estimation (called NCUE). The NCUE generates a weight for each sample/pseudo-label, highlighting consistent predictions and downplaying inconsistent ones. The weights, together with the pseudo-labels, are then used in a weighted cross-entropy loss that supervises training on the strongly-augmented version.
  • Figure 3: The amount of low-uncertainty (high-confidence) pseudo-labels obtained by NCUE (the blue curve) and by threshold-based FixMatch (the green curve, with the threshold set to 0.95 as in Sohn et al., 2020). The red curve denotes the number of correct pseudo-labels, measured using ground-truth labels. The x-axis represents the training epoch, while the y-axis denotes the percentage (%) of total samples.
  • Figure 4: Visualization of 1) the accuracy of high/low-confidence predictions (the blue and red curves) and the weight distributions (the dark curve) of our NorMatch; 2) the accuracy of predictions from the threshold-based FixMatch (the green curve); 3) the uncertainties of the discriminative classifier and the NFC (the dashed gray and brown curves, respectively) of our NorMatch.
  • Figure 5: Sensitivity of the model's performance to $\lambda$.
  • ...and 1 more figures
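
The pipeline described in the Figure 2 caption (pseudo-labels from the weak view, per-sample consensus weights, weighted cross-entropy on the strong view) can be sketched roughly as follows. This is a minimal illustration, not the authors' implementation: the weighting rule in `consensus_weight` (the product of the two classifiers' probabilities for the pseudo-label) is a hypothetical stand-in for NCUE, chosen only because it is threshold-free and shrinks when the classifiers disagree. All function names are made up for this example.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def consensus_weight(p_disc, p_nfc):
    """Return (pseudo_label, weight) for one weakly-augmented sample.

    ASSUMPTION: the weight is the product of both classifiers'
    probabilities for the discriminative pseudo-label. This is an
    illustrative rule, not the NCUE formula from the paper.
    """
    pseudo_label = max(range(len(p_disc)), key=p_disc.__getitem__)
    weight = p_disc[pseudo_label] * p_nfc[pseudo_label]
    return pseudo_label, weight


def weighted_cross_entropy(p_strong_batch, pseudo_labels, weights):
    """Mean of per-sample cross-entropy terms, each scaled by its weight."""
    losses = [
        -w * math.log(max(p[y], 1e-12))  # clamp to avoid log(0)
        for p, y, w in zip(p_strong_batch, pseudo_labels, weights)
    ]
    return sum(losses) / len(losses)


# Two unlabeled samples: the classifiers agree on A and disagree on B.
label_a, weight_a = consensus_weight([0.9, 0.1], [0.8, 0.2])
label_b, weight_b = consensus_weight([0.9, 0.1], [0.2, 0.8])

# Predictions on the strongly-augmented views (hypothetical logits).
p_strong = [softmax([2.0, 0.0]), softmax([2.0, 0.0])]
loss = weighted_cross_entropy(p_strong, [label_a, label_b], [weight_a, weight_b])
```

Both samples receive the same pseudo-label, but the sample on which the classifiers agree contributes with a larger weight, so no confidence threshold is needed to decide which pseudo-labels to keep.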