Variational Classification

Shehzaad Dhuliawala, Mrinmaya Sachan, Carl Allen

TL;DR

This work reframes standard softmax classification as a latent-variable model and introduces Variational Classification (VC), a probabilistic training objective, analogous to the ELBO, that aligns the empirical distribution of latents with a chosen anticipated distribution. By treating the softmax input as a latent variable, modelled by an encoder $q_\phi(z|x)$, and computing the output layer $p_\theta(y|z)$ by Bayes' rule, VC generalises the cross-entropy loss and gives explicit control over the latent class priors and distributions. Empirically, VC matches the accuracy of a standard softmax classifier while substantially improving calibration, adversarial robustness, robustness to distribution shift, and sample efficiency in low-data settings, with minimal overhead and no extra hyperparameter tuning. The approach provides theoretical insight into softmax classifiers and suggests future extensions to stochastic latent encoders and semi-supervised learning.
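The Bayes-rule output layer can be illustrated with a minimal NumPy sketch. It assumes isotropic Gaussian class-conditionals $p_\theta(z|y) = \mathcal{N}(\mu_y, \sigma^2 I)$ with one shared variance, an illustrative choice; the function names are hypothetical. Under this assumption, $p_\theta(y|z)$ computed by Bayes' rule coincides with a standard softmax over linear logits, which is one way to read a softmax layer as a latent-variable model.

```python
import numpy as np

def log_gaussian_isotropic(z, means, sigma2):
    """log N(z; mu_k, sigma2 * I) for every class k.

    z: (d,) latent vector; means: (K, d) class means; sigma2: shared variance.
    Returns a (K,) array of log-densities.
    """
    d = z.shape[0]
    sq_dist = np.sum((means - z) ** 2, axis=1)
    return -0.5 * sq_dist / sigma2 - 0.5 * d * np.log(2 * np.pi * sigma2)

def bayes_output_layer(z, means, sigma2, log_prior):
    """p(y|z) proportional to p(z|y) p(y); the max log-joint is subtracted
    before exponentiating for numerical stability."""
    log_joint = log_gaussian_isotropic(z, means, sigma2) + log_prior
    log_joint -= log_joint.max()
    probs = np.exp(log_joint)
    return probs / probs.sum()

def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()

# With a shared isotropic covariance, the ||z||^2 term is identical for all
# classes and cancels in the normalisation, leaving linear logits
#   w_k . z + b_k,  w_k = mu_k / sigma2,  b_k = -||mu_k||^2 / (2 sigma2) + log p(y=k).
rng = np.random.default_rng(0)
K, d, sigma2 = 3, 2, 0.5
means = rng.normal(size=(K, d))
log_prior = np.log(np.array([0.2, 0.3, 0.5]))
z = rng.normal(size=d)

p_bayes = bayes_output_layer(z, means, sigma2, log_prior)
W = means / sigma2
b = -np.sum(means ** 2, axis=1) / (2 * sigma2) + log_prior
p_softmax = softmax(W @ z + b)
print(np.allclose(p_bayes, p_softmax))  # True
```

With class-specific covariances the same Bayes-rule layer would instead produce quadratic logits, so the shared-variance case is the one that lines up exactly with a linear softmax layer.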

Abstract

We present a latent variable model for classification that provides a novel probabilistic interpretation of neural network softmax classifiers. We derive a variational objective to train the model, analogous to the evidence lower bound (ELBO) used to train variational auto-encoders, that generalises the softmax cross-entropy loss. Treating inputs to the softmax layer as samples of a latent variable, our abstracted perspective reveals a potential inconsistency between their anticipated distribution, required for accurate label predictions, and their empirical distribution found in practice. We augment the variational objective to mitigate such inconsistency and induce a chosen latent distribution, instead of the implicit assumption found in a standard softmax layer. Overall, we provide new theoretical insight into the inner workings of widely-used softmax classifiers. Empirical evaluation on image and text classification datasets demonstrates that our proposed approach, variational classification, maintains classification accuracy while the reshaped latent space improves other desirable properties of a classifier, such as calibration, adversarial robustness, robustness to distribution shift and sample efficiency useful in low data settings.
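How the variational objective relates to the incremental objectives compared in Figure 1 can be sketched as follows. This assumes the objective takes an ELBO-like form: an expected log-likelihood of the label minus a KL term between the empirical class-conditional latent distribution $q_\phi(z|y)$ and the class prior $p_\theta(z|y)$. It is a reconstruction under that assumption, not the paper's exact equation, which may include additional weighting.

$$
\mathcal{L}_{\mathrm{VC}}
= \mathbb{E}_{q(y)}\Big[\,\mathbb{E}_{q_\phi(z|y)}\big[\log p_\theta(y|z)\big]
- \mathrm{KL}\big(q_\phi(z|y)\,\|\,p_\theta(z|y)\big)\Big]
$$

Expanding $\mathrm{KL}(q\,\|\,p) = -\mathbb{E}_q[\log p] - H[q]$ splits this into three terms:

$$
\mathcal{L}_{\mathrm{VC}}
= \mathbb{E}_{q(y)}\Big[\,
\underbrace{\mathbb{E}_{q_\phi(z|y)}\big[\log p_\theta(y|z)\big]}_{\text{MLE: negative softmax cross-entropy}}
+ \underbrace{\mathbb{E}_{q_\phi(z|y)}\big[\log p_\theta(z|y)\big]}_{\text{added in MAP: class prior}}
+ \underbrace{H\big[q_\phi(z|y)\big]}_{\text{added in VC: entropy}}
\Big]
$$

Keeping only the first term recovers standard softmax cross-entropy training, which is the sense in which the objective generalises it.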

Paper Structure

This paper contains 25 sections, 20 equations, 8 figures, 4 tables, 1 algorithm.

Figures (8)

  • Figure 1: Empirical distributions $q_\phi(z|y)$ of inputs to the output layer for classifiers trained under incremental components of the VC objective on MNIST (cf. the central $\mathcal{Z}$-plane in Figure 2). (l) "MLE" objective = softmax cross-entropy; (c) "MAP" objective = MLE + Gaussian class priors $p_\theta(z|y)$ (shown as contours); (r) VC objective = MAP + entropy of $q_\phi(z|y)$. Colour indicates class $y$; $\mathcal{Z} = \mathbb{R}^2$ for visualisation purposes.
  • Figure 2: Variational Classification, reversing the generative process: $q_\phi(z|x)$ maps data $x \in \mathcal{X}$ to the latent space $\mathcal{Z}$, where empirical distributions $q_\phi(z|y)$ are fitted to class priors $p_\theta(z|y)$; the top layer computes $p_\theta(y|z)$ by Bayes' rule to give a class prediction $p(y|x)$.
  • Figure 3: Calibration under distribution shift: (l) CIFAR-10-C, (m) CIFAR-100-C, (r) Tiny-ImageNet-C. Boxes indicate quartiles and whiskers indicate min/max across 16 types of synthetic distribution shift.
  • Figure 4: Prediction accuracy under FGSM adversarial attacks of increasing strength: (l) MNIST; (r) CIFAR-10.
  • Figure 5: Accuracy increase of VC vs CE on 10 MedMNIST classification datasets of varying training set size. Blue points indicate accuracy on a dataset (mean, std.err., 3 runs). Green line shows a best-fit trend across dataset size.
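Calibration, as in Figure 3, is often summarised by the expected calibration error (ECE): the bin-size-weighted gap between accuracy and mean confidence over confidence bins. The following is a minimal NumPy sketch, assuming equal-width bins (15 here, an arbitrary choice). The function name is hypothetical, and the exact calibration metric and binning behind the figure are not specified in this section.

```python
import numpy as np

def expected_calibration_error(probs, labels, n_bins=15):
    """Expected calibration error with equal-width confidence bins.

    probs: (N, K) predicted class probabilities; labels: (N,) integer labels.
    ECE = sum_b (|B_b| / N) * |acc(B_b) - conf(B_b)|.
    """
    confidences = probs.max(axis=1)
    predictions = probs.argmax(axis=1)
    correct = (predictions == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Bins are (lo, hi]; max-probabilities are always > 0, so none are dropped.
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap  # in_bin.mean() = |B_b| / N
    return ece

# Toy example: four predictions, all with confidence 0.9, three of them correct.
probs = np.array([[0.9, 0.1], [0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
labels = np.array([0, 0, 1, 0])
print(expected_calibration_error(probs, labels))  # approximately 0.15
```

Here the model is overconfident: it reports 0.9 confidence but is right only 75% of the time, giving a gap of 0.15 in the single occupied bin.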