Critic Loss for Image Classification

Brendan Hogan Rappazzo, Aaron Ferber, Carla Gomes

TL;DR

CrtCl proposes a learned loss for image classification by framing training as a generator-critic game between a classifier $G_{\theta}$ and a correctness critic $C_{\phi}$. The critic estimates the probability that the classifier's prediction is correct, enabling a Wasserstein-distance-based loss $\mathcal{L}_{cc}$ that is backpropagated through the critic to the classifier alongside the standard cross-entropy. Because the critic does not take the label as input, the classifier can learn from both labeled and unlabeled data. The framework supports semi-supervised learning, by applying the critic to unlabeled samples, and active learning, by selecting informative unlabeled examples for labeling. Across SVHN, CIFAR10, and CIFAR100, CrtCl achieves better generalization (higher test accuracy) and better calibration (lower ECE) than strong baselines, particularly in low-label regimes, with consistent improvements and a transparent computational trade-off.

Abstract

Modern neural network classifiers achieve remarkable performance across a variety of tasks; however, they frequently exhibit overconfidence in their predictions due to the cross-entropy loss. Motivated by this problem, we propose the Critic Loss for Image Classification (CrtCl, pronounced "critical"). CrtCl formulates image classification training in a generator-critic framework, with a base classifier acting as the generator and a correctness critic imposing a loss on the classifier. Given images, the base classifier generates a probability distribution over classes as well as intermediate embeddings. Given the image, the intermediate embeddings, and the output predictions of the base model, the critic predicts the probability that the base model has produced the correct classification, which can then be backpropagated as a self-supervision signal. Notably, the critic does not use the label as input, so it can train the base model on both labeled and unlabeled data in semi-supervised learning settings. CrtCl is a learned loss for accuracy that alleviates the negative side effects of the cross-entropy loss. Additionally, CrtCl provides a powerful way to select data to be labeled in an active learning setting, by estimating the classification ability of the base model on unlabeled data. We study the effectiveness of CrtCl in low-labeled-data regimes and in the context of active learning. In classification, we find that CrtCl, compared to recent baselines, increases classifier generalization and calibration across various amounts of labeled data. In active learning, we show that our method outperforms baselines in accuracy and calibration. We observe consistent results across three image classification datasets.
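The training signal described in the abstract can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the paper's implementation: the critic is fit with binary cross-entropy against correctness targets (whether the classifier's argmax matches the label), and the classifier's critic loss is assumed to be $-\log C_{\phi}(\cdot)$, which needs no label. The section does not give the exact form of the loss (the TL;DR calls it Wasserstein-distance-based), and all function names are hypothetical. Active-learning selection is sketched as picking the unlabeled points with the lowest critic score.

```python
import math
from typing import List, Sequence


def correctness_targets(probs: Sequence[Sequence[float]], labels: Sequence[int]) -> List[float]:
    """Critic targets: 1.0 where the classifier's argmax equals the label, else 0.0.

    Requires labels, so in this sketch the critic is fit on labeled data only.
    """
    targets = []
    for p, y in zip(probs, labels):
        pred = max(range(len(p)), key=lambda k: p[k])  # ARGMAX of class probabilities
        targets.append(1.0 if pred == y else 0.0)
    return targets


def critic_bce(critic_scores: Sequence[float], targets: Sequence[float], eps: float = 1e-7) -> float:
    """Binary cross-entropy between critic outputs in [0, 1] and correctness targets.

    BCE is an assumed stand-in for the paper's critic objective.
    """
    total = 0.0
    for c, t in zip(critic_scores, targets):
        c = min(max(c, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(t * math.log(c) + (1.0 - t) * math.log(1.0 - c))
    return total / len(targets)


def classifier_critic_loss(critic_scores: Sequence[float], eps: float = 1e-7) -> float:
    """Hypothetical label-free loss for the classifier: -mean(log C).

    Minimizing it pushes the critic's predicted correctness toward 1. In a real
    model the gradient would flow through the critic into the classifier; this
    plain-Python sketch only computes the scalar value.
    """
    return sum(-math.log(max(c, eps)) for c in critic_scores) / len(critic_scores)


def select_for_labeling(critic_scores: Sequence[float], budget: int) -> List[int]:
    """Active learning: indices of unlabeled points the critic rates least likely correct."""
    order = sorted(range(len(critic_scores)), key=lambda i: critic_scores[i])
    return order[:budget]


if __name__ == "__main__":
    probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.4, 0.5, 0.1]]
    labels = [0, 2, 0]
    targets = correctness_targets(probs, labels)
    print(targets)  # [1.0, 1.0, 0.0]
    print(critic_bce([0.9, 0.8, 0.3], targets))
    print(select_for_labeling([0.9, 0.2, 0.6, 0.4], budget=2))  # [1, 3]
```

In a full model, $G_{\theta}$ and $C_{\phi}$ would be neural networks trained with an autodiff framework; how critic and classifier updates are interleaved is likewise an assumption not specified in this section.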

Paper Structure (25 sections, 6 equations, 5 figures, 1 table, 1 algorithm)

Figures (5)

  • Figure 1: A schematic of CrtCl. The classifier $G_{\theta}$ takes in images and produces intermediate representations and class probabilities; the $\mathtt{ARGMAX}$ of the probabilities is the classification. The critic network $C_{\phi}$ takes in the representations of $G_{\theta}$ and predicts whether $G_{\theta}$ classifies an example correctly. Once trained, the critic network can be used as a learned loss on both labeled and unlabeled data to train the generator to be more correct, while avoiding the miscalibration caused by cross-entropy loss. Further, the critic's predictions on unlabeled data points can be used to suggest likely misclassified points for active learning.
  • Figure 2: The accuracy and expected calibration error (ECE) of our method compared to Learning Loss [ll_pp], TOD [huang2022temporal], PT4AL [pretrain], and a standard training baseline on all three datasets. For all datasets, in the majority of active learning cycles, our method produced more generalizable models (higher test accuracy) and better calibrated models (lower ECE).
  • Figure 3: For this setting, we took a standard network trained with cross-entropy on CIFAR10 and produced t-SNE embeddings for the entire dataset. We then show, for the highest-performing methods, which data points each method selects to be labeled at each cycle. Fully opaque points are those chosen to be labeled; unlabeled points are shown as highly transparent. Additionally, we compute the mean silhouette score for each cluster. Intuitively, higher scores indicate that the model is selecting points more similar to the data already in the labeled set, whereas lower scores indicate that the model is selecting more diverse samples that are spread further apart in the feature space. We observe that early in training our model selects more similar samples, and later in training it achieves the lowest silhouette score, indicating that it selects the most diverse samples per class.
  • Figure 4: Our first ablation study on CIFAR10, in which, for all methods, the next data points to be labeled are sampled randomly, but each method's auxiliary loss function is still used. Our loss function leads to better generalization (higher accuracy) and better calibration (lower ECE).
  • Figure 5: Our second ablation study on CIFAR10, in which each method selects the next data points with its own selection strategy, but its auxiliary loss is not backpropagated. While our method still outperforms Learning Loss, TOD performs best, indicating that the benefits of our method are more entangled with the critic loss, which is intuitively expected given the nature of the generator-critic learning process.
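Figures 2 and 3 rely on two standard metrics: expected calibration error, $\mathrm{ECE} = \sum_{b} \frac{|B_b|}{n}\,\lvert \mathrm{acc}(B_b) - \mathrm{conf}(B_b) \rvert$ over confidence bins $B_b$, and the per-cluster mean silhouette score. A minimal plain-Python sketch of both follows. The 15 equal-width bins, the Euclidean distance, and scoring singleton clusters as 0 are assumptions; the section does not state the paper's settings, nor exactly which points and cluster ids enter the silhouette computation (here, any points such as 2-D t-SNE coordinates with one cluster id per point). Function names are hypothetical.

```python
import math
from typing import Dict, Hashable, List, Sequence, Tuple


def expected_calibration_error(
    confidences: Sequence[float],
    predictions: Sequence[int],
    labels: Sequence[int],
    n_bins: int = 15,
) -> float:
    """ECE over equal-width confidence bins (lo, hi]."""
    n = len(confidences)
    ece = 0.0
    for b in range(n_bins):
        lo, hi = b / n_bins, (b + 1) / n_bins
        # A confidence of exactly 0 is assigned to the first bin.
        idx = [i for i, c in enumerate(confidences) if lo < c <= hi or (b == 0 and c == 0.0)]
        if not idx:
            continue
        acc = sum(predictions[i] == labels[i] for i in idx) / len(idx)
        conf = sum(confidences[i] for i in idx) / len(idx)
        ece += len(idx) / n * abs(acc - conf)  # bin weight times accuracy/confidence gap
    return ece


def mean_silhouette_per_cluster(
    points: Sequence[Tuple[float, ...]], cluster_ids: Sequence[Hashable]
) -> Dict[Hashable, float]:
    """Mean silhouette s(i) = (b - a) / max(a, b) per cluster (needs >= 2 clusters).

    a: mean distance from point i to the other members of its cluster.
    b: smallest mean distance from point i to the members of another cluster.
    """
    members: Dict[Hashable, List[int]] = {}
    for i, k in enumerate(cluster_ids):
        members.setdefault(k, []).append(i)
    if len(members) < 2:
        raise ValueError("silhouette needs at least two clusters")
    totals = {k: 0.0 for k in members}
    for i, k in enumerate(cluster_ids):
        same = [j for j in members[k] if j != i]
        if not same:
            s = 0.0  # singleton cluster: silhouette taken as 0
        else:
            a = sum(math.dist(points[i], points[j]) for j in same) / len(same)
            b = min(
                sum(math.dist(points[i], points[j]) for j in idx) / len(idx)
                for o, idx in members.items()
                if o != k
            )
            s = (b - a) / max(a, b) if max(a, b) > 0 else 0.0
        totals[k] += s
    return {k: totals[k] / len(members[k]) for k in members}


if __name__ == "__main__":
    # Approximately 0.35: one overconfident bin and one underconfident bin.
    print(expected_calibration_error([0.85, 0.85, 0.65, 0.65], [1, 1, 0, 0], [1, 0, 0, 0], n_bins=10))
    pts = [(0.0, 0.0), (0.0, 1.0), (10.0, 0.0), (10.0, 1.0)]
    print(mean_silhouette_per_cluster(pts, [0, 0, 1, 1]))  # both clusters score about 0.90
```

Under the caption's reading, a lower per-class silhouette for newly selected points would suggest a more spread-out, diverse selection.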