Universal Training of Neural Networks to Achieve Bayes Optimal Classification Accuracy

Mohammadreza Tavasoli Naeini, Ali Bereyhi, Morteza Noshad, Ben Liang, Alfred O. Hero

TL;DR

The paper tackles the problem of approaching Bayes-optimal classification by deriving a universal, sample-based upper bound on the Bayes error via $f$-divergence and hinge-loss connections. This bound is reinterpreted as a trainable loss, the Bayes Optimal Learning Threshold (BOLT), which, when minimized, guarantees $\varepsilon_\mathrm{bys} \le \min_{\theta} \mathcal{L}_{\theta}$. Empirical results on MNIST, Fashion-MNIST, CIFAR-10, and IMDb show that BOLT can match or surpass cross-entropy, particularly on harder datasets like CIFAR-10, indicating improved generalization. The work provides a principled objective to align training with Bayes accuracy and suggests potential extensions beyond traditional classification tasks.

Abstract

This work invokes the notion of $f$-divergence to introduce a novel upper bound on the Bayes error rate of a general classification task. We show that the proposed bound can be computed by sampling from the output of a parameterized model. Using this practical interpretation, we introduce the Bayes optimal learning threshold (BOLT) loss, whose minimization forces a classification model to achieve the Bayes error rate. We validate the proposed loss on image and text classification tasks, considering the MNIST, Fashion-MNIST, CIFAR-10, and IMDb datasets. Numerical experiments demonstrate that models trained with BOLT achieve performance on par with or exceeding that of cross-entropy, particularly on challenging datasets. This highlights the potential of BOLT for improving generalization.

Paper Structure

This paper contains 14 sections, 3 theorems, 21 equations, 2 figures, and 1 table.

Key Result

Lemma 1

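Let $P$ and $Q$ be distributions defined on $\mathcal{X}$. The $f$-divergence is bounded from below as

$$D_f(P \,\|\, Q) \ge \sup_{h \in \mathcal{H}} \left\{ \mathbb{E}_{x \sim P}\left[h(x)\right] - \mathbb{E}_{x \sim Q}\left[f^*\left(h(x)\right)\right] \right\},$$

with $f^*$ denoting the Fenchel conjugate of $f$, defined as

$$f^*(t) = \sup_{u \in \mathrm{dom}\, f} \left\{ u t - f(u) \right\},$$

and $\mathcal{H}$ being a suitable class of real-valued functions. The class $\mathcal{H}$ typically includes bounded measurable functions.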
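To make the variational lower bound $\mathbb{E}_P[h(X)] - \mathbb{E}_Q[f^*(h(X))]$ concrete, here is a Monte Carlo sketch for the KL divergence, in the spirit of Example 2 (whose exact statement is not reproduced in this summary). It assumes $P = \mathcal{N}(\mu, 1)$ and $Q = \mathcal{N}(0, 1)$, for which $\mathrm{KL}(P\|Q) = \mu^2/2$, and uses $f(u) = u \log u$ with conjugate $f^*(t) = e^{t-1}$. The critic $h(x) = f'(p(x)/q(x)) = 1 + \log\frac{p(x)}{q(x)}$ makes the bound tight, while a scaled-down critic gives a strictly smaller value. The function `variational_kl_bound` and all parameter values are illustrative assumptions, not the paper's code.

```python
import numpy as np

# Illustrative setup (assumption, not from the paper):
# P = N(mu, 1), Q = N(0, 1), so KL(P || Q) = mu**2 / 2 in closed form.
# For KL, f(u) = u * log(u) and its Fenchel conjugate is f*(t) = exp(t - 1).
rng = np.random.default_rng(0)
mu = 1.5
n = 200_000
xs_p = rng.normal(mu, 1.0, size=n)   # samples from P
xs_q = rng.normal(0.0, 1.0, size=n)  # samples from Q


def variational_kl_bound(h, xs_p, xs_q):
    """Monte Carlo estimate of E_P[h(X)] - E_Q[f*(h(X))] with f*(t) = exp(t - 1)."""
    return np.mean(h(xs_p)) - np.mean(np.exp(h(xs_q) - 1.0))


def log_ratio(x):
    """log p(x)/q(x) for the two unit-variance Gaussians above."""
    return mu * x - 0.5 * mu**2


def h_opt(x):
    # Optimal critic h*(x) = f'(p/q) = 1 + log(p/q): the bound equals KL.
    return 1.0 + log_ratio(x)


def h_sub(x):
    # Deliberately suboptimal critic: the bound is strictly below KL.
    return 1.0 + 0.5 * log_ratio(x)


exact = 0.5 * mu**2
bound_opt = variational_kl_bound(h_opt, xs_p, xs_q)
bound_sub = variational_kl_bound(h_sub, xs_p, xs_q)
print(f"exact KL          : {exact:.4f}")
print(f"bound, optimal h  : {bound_opt:.4f}")
print(f"bound, suboptimal : {bound_sub:.4f}")
```

For $\mu = 1.5$ the exact value is $1.125$; the optimal critic recovers it up to Monte Carlo noise, whereas the suboptimal critic yields $1 + \mu^2/4 - e^{-\mu^2/8} \approx 0.808$ in expectation.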

Figures (2)

  • Figure 1: Comparing the Bayes error rate $\varepsilon_\mathrm{bys}$ with the classification error achieved by a neural network trained with the BOLT loss: the error of the trained model almost perfectly matches $\varepsilon_\mathrm{bys}$.
  • Figure 2: Test accuracy of ResNet-18 trained on CIFAR-10 using cross-entropy (CE) and BOLT loss: the test accuracy improves by 1.34% at the final epoch with the BOLT loss.
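The comparison in Figure 1 presupposes a problem whose Bayes error is known. A minimal sketch of such a comparison on an assumed synthetic problem: two equiprobable classes with $X \mid Y=0 \sim \mathcal{N}(-\mu, 1)$ and $X \mid Y=1 \sim \mathcal{N}(\mu, 1)$, so the posterior is $\eta(x) = \sigma(2\mu x)$, the Bayes rule thresholds at $x = 0$, and $\varepsilon_\mathrm{bys} = \Phi(-\mu)$. It estimates $\varepsilon_\mathrm{bys} = \mathbb{E}[\min(\eta(X), 1 - \eta(X))]$ from samples and compares it with the empirical error of two threshold classifiers. The thresholds stand in for a trained network, and the helper `error_rate` is hypothetical; this is not the experimental setup of the paper.

```python
import math

import numpy as np

# Illustrative model (assumption): equal priors,
# X | Y=0 ~ N(-mu, 1) and X | Y=1 ~ N(mu, 1).
rng = np.random.default_rng(1)
mu, n = 1.0, 500_000
y = rng.integers(0, 2, size=n)
x = rng.normal(np.where(y == 1, mu, -mu), 1.0)

# Closed form: the Bayes rule predicts 1 iff x > 0, so eps_bys = Phi(-mu).
eps_closed = 0.5 * (1.0 + math.erf(-mu / math.sqrt(2.0)))

# Sample-based estimate: eps_bys = E[min(eta(X), 1 - eta(X))],
# with posterior eta(x) = P(Y=1 | X=x) = sigmoid(2 * mu * x).
eta = 1.0 / (1.0 + np.exp(-2.0 * mu * x))
eps_mc = np.mean(np.minimum(eta, 1.0 - eta))


def error_rate(threshold):
    """Empirical error of the (hypothetical) classifier predicting 1 iff x > threshold."""
    return np.mean((x > threshold).astype(int) != y)


print(f"Bayes error (closed form) : {eps_closed:.4f}")
print(f"Bayes error (Monte Carlo) : {eps_mc:.4f}")
print(f"error, threshold 0.0      : {error_rate(0.0):.4f}")
print(f"error, threshold 0.5      : {error_rate(0.5):.4f}")
```

With $\mu = 1$, the closed form gives $\Phi(-1) \approx 0.1587$. The threshold-0 rule matches it up to sampling noise, while the threshold-0.5 rule has expected error $\tfrac12\Phi(-1.5) + \tfrac12\Phi(-0.5) \approx 0.188$, illustrating that no classifier goes below $\varepsilon_\mathrm{bys}$.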

Theorems & Definitions (9)

  • Definition 1: $f$-Divergence
  • Example 1: KL divergence
  • Lemma 1: Bounding $f$-Divergence [nowozin2016f, liese2006statistical, liese2006divergences]
  • Example 2: Bounding KL divergence
  • Definition 2: Hinge loss
  • Theorem 1: Binary Bayes Error
  • Proof of Theorem 1
  • Theorem 2: Multi-Class Bayes error
  • Proof of Theorem 2
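The connection between the hinge loss and the binary Bayes error can be illustrated with a classical fact: over all score functions $h$, the minimal expected hinge loss $\mathbb{E}[\max(0, 1 - Y h(X))]$ with labels $Y \in \{-1, +1\}$ equals $2\varepsilon_\mathrm{bys}$, attained by $h(x) = \operatorname{sign}(2\eta(x) - 1)$ with $\eta(x) = P(Y = +1 \mid X = x)$. Hence half the hinge risk of any $h$ upper-bounds the Bayes error. The sketch below checks this numerically on an assumed 1-D Gaussian problem; it illustrates the general principle only and is not claimed to be the exact statement of Theorem 1 or the BOLT loss. The helper `half_hinge_risk` is hypothetical.

```python
import math

import numpy as np

# Illustrative model (assumption): equal priors, labels s in {-1, +1},
# X | s ~ N(s * mu, 1). Here the Bayes rule is sign(x) and eps_bys = Phi(-mu).
rng = np.random.default_rng(2)
mu, n = 1.0, 500_000
s = rng.choice([-1, 1], size=n)
x = rng.normal(s * mu, 1.0)
eps_bys = 0.5 * (1.0 + math.erf(-mu / math.sqrt(2.0)))  # Phi(-mu)


def half_hinge_risk(scores):
    """Monte Carlo estimate of (1/2) * E[max(0, 1 - s * h(X))]."""
    return 0.5 * np.mean(np.maximum(0.0, 1.0 - s * scores))


candidates = {
    "h(x) = sign(x)": np.sign(x),  # Bayes rule sign(2*eta - 1): bound is tight
    "h(x) = x": x,                 # unclipped linear score: bound is loose
    "h(x) = 2x": 2.0 * x,          # another unclipped score, also loose
}

print(f"Bayes error: {eps_bys:.4f}")
for name, scores in candidates.items():
    print(f"{name:15s} half hinge risk = {half_hinge_risk(scores):.4f}")
```

For $\mu = 1$, the score $\operatorname{sign}(x)$ gives half hinge risk equal to $\Phi(-1) \approx 0.1587$ up to sampling noise, because its hinge loss is $0$ on correct predictions and $2$ on errors; the linear score $h(x) = x$ gives $1/(2\sqrt{2\pi}) \approx 0.1995$, which stays above the Bayes error as the bound requires.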