A Method of Moments Embedding Constraint and its Application to Semi-Supervised Learning
Michael Majurski, Sumeet Menon, Parniyan Farvardin, David Chapman
TL;DR
This work tackles over-confidence in discriminative softmax classifiers by modeling the joint distribution $p(Y,X)$ with a differentiable, generative final layer, an Axis-Aligned Gaussian Mixture Model (AAGMM), trained together with a Method of Moments (MoM) embedding constraint. The approach constrains latent representations to match the moments of a standard multivariate normal distribution up to 4th order and uses an AAGMM final layer. This yields latent clusters that reflect the structure of the data while preserving discriminative performance. Empirical results on CIFAR-10 and STL-10 with 40 labels show that MoM constraints, particularly with the AAGMM layer, achieve accuracy competitive with FlexMatch and provide a framework for outlier detection via Mahalanobis distance. The combination offers a path to robust semi-supervised learning that can flag atypical samples and better capture the geometry of the latent space. The costs are higher GPU memory demands for higher-order moments and sensitivity to the choice of outlier threshold.
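The moment-matching idea can be illustrated with a small sketch. It assumes raw (uncentered) moment tensors and an unweighted mean-squared penalty between empirical and target moments, and the paper's exact formulation and weighting may differ. The names `standard_normal_moments`, `empirical_moments`, and `mom_loss` are hypothetical. NumPy is used only to show the arithmetic. In training, the same quantities would be computed on mini-batch embeddings in an autodiff framework and added to the classification loss. For $N(0, I)$, raw and central moments coincide. The 4th-order target follows Isserlis' theorem.

```python
import numpy as np

def standard_normal_moments(d):
    """Target raw moments of N(0, I_d) for orders 1 through 4."""
    eye = np.eye(d)
    m1 = np.zeros(d)              # E[z_i] = 0
    m2 = eye.copy()               # E[z_i z_j] = delta_ij
    m3 = np.zeros((d, d, d))      # odd moments of a centered Gaussian vanish
    # Isserlis' theorem: E[z_i z_j z_k z_l] = d_ij d_kl + d_ik d_jl + d_il d_jk
    m4 = (np.einsum("ij,kl->ijkl", eye, eye)
          + np.einsum("ik,jl->ijkl", eye, eye)
          + np.einsum("il,jk->ijkl", eye, eye))
    return [m1, m2, m3, m4]

def empirical_moments(z):
    """Raw moment tensors of a batch of embeddings z with shape (n, d)."""
    n = z.shape[0]
    m1 = z.mean(axis=0)
    m2 = np.einsum("ni,nj->ij", z, z) / n
    m3 = np.einsum("ni,nj,nk->ijk", z, z, z) / n
    m4 = np.einsum("ni,nj,nk,nl->ijkl", z, z, z, z) / n  # d**4 entries
    return [m1, m2, m3, m4]

def mom_loss(z, order=4):
    """Hypothetical MoM penalty: mean squared gap to N(0, I) moments, orders 1..order."""
    if not 1 <= order <= 4:
        raise ValueError("order must be between 1 and 4")
    targets = standard_normal_moments(z.shape[1])
    estimates = empirical_moments(z)
    return float(sum(np.mean((estimates[k] - targets[k]) ** 2) for k in range(order)))

# Usage: samples from N(0, I) score near zero; rescaled samples do not.
rng = np.random.default_rng(0)
z_normal = rng.standard_normal((50_000, 3))
loss_normal = mom_loss(z_normal)
loss_scaled = mom_loss(2.0 * z_normal)
```

The 4th-order tensor has $d^4$ entries for a $d$-dimensional embedding. This is consistent with the higher GPU memory demand that higher-order moments bring.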
Abstract
Discriminative deep learning models with a linear+softmax final layer have a problem: the latent space only predicts the conditional probabilities $p(Y|X)$ and not the full joint distribution $p(Y,X)$, and modeling the joint distribution requires a generative approach. The conditional probability alone cannot detect outliers, which makes softmax networks sensitive to them. This exacerbates model over-confidence, which affects many problems such as hallucinations, confounding biases, and dependence on large datasets. To address this, we introduce a novel embedding constraint based on the Method of Moments (MoM). We investigate polynomial moments ranging from 1st- through 4th-order hyper-covariance matrices. Furthermore, we use this embedding constraint to train an Axis-Aligned Gaussian Mixture Model (AAGMM) final layer, which learns not only the conditional but also the joint distribution of the latent space. We apply this method to semi-supervised image classification by extending FlexMatch with our technique. We find that our MoM constraint with the AAGMM layer matches the reported FlexMatch accuracy while also modeling the joint distribution, thereby reducing outlier sensitivity. We also present a preliminary outlier detection strategy based on Mahalanobis distance and discuss future improvements to this strategy. Code is available at: https://github.com/mmajurski/ssl-gmm
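The following minimal sketch shows how an axis-aligned (diagonal-covariance) Gaussian mixture head can expose both the conditional $p(Y|X)$ and the density of the embedding, and how a Mahalanobis distance can flag atypical points. It is an illustration under stated assumptions, not the repository's implementation. The helpers `aagmm_forward` and `flag_outliers`, the learned log-priors, and the threshold of 3.0 are hypothetical. The abstract describes the outlier strategy only as preliminary and Mahalanobis-based. In a real network, `means` and `log_vars` would be trainable parameters updated by backpropagation.

```python
import numpy as np

def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def aagmm_forward(z, means, log_vars, log_priors):
    """Hypothetical AAGMM head. z: (n, d); means, log_vars: (K, d); log_priors: (K,)."""
    d = z.shape[1]
    var = np.exp(log_vars)                              # per-axis variances, (K, d)
    diff = z[:, None, :] - means[None, :, :]            # (n, K, d)
    maha_sq = np.sum(diff ** 2 / var[None], axis=-1)    # squared Mahalanobis distance, (n, K)
    log_det = np.sum(log_vars, axis=-1)                 # log|Sigma_k| for diagonal Sigma_k
    # log p(z, y=k) = log pi_k + log N(z; mu_k, diag(var_k))
    log_joint = log_priors[None] - 0.5 * (maha_sq + log_det[None] + d * np.log(2 * np.pi))
    log_px = logsumexp(log_joint, axis=1)               # log p(z), the mixture density
    log_post = log_joint - log_px[:, None]              # log p(y=k | z), the class posterior
    min_maha = np.sqrt(np.min(maha_sq, axis=1))         # distance to the nearest component
    return log_post, log_px, min_maha

def flag_outliers(min_maha, threshold=3.0):
    """Hypothetical rule: flag points farther than `threshold` from every component."""
    return min_maha > threshold

# Usage: two unit-variance clusters at (0, 0) and (10, 0) with equal priors.
means = np.array([[0.0, 0.0], [10.0, 0.0]])
log_vars = np.zeros((2, 2))
log_priors = np.log(np.array([0.5, 0.5]))
z = np.array([[0.0, 0.0], [10.0, 0.5], [30.0, 0.0]])
log_post, log_px, min_maha = aagmm_forward(z, means, log_vars, log_priors)
is_outlier = flag_outliers(min_maha)
```

The point at (30, 0) receives a class-1 posterior near 1, even though it lies 20 standard deviations from that cluster. The conditional alone looks confident. Its low log density and large Mahalanobis distance mark it as atypical, which is the distinction between modeling $p(Y|X)$ and modeling the joint distribution.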
