Joint-stochastic-approximation Random Fields with Application to Semi-supervised Learning
Yunfu Song, Zhijian Ou
TL;DR
The paper tackles semi-supervised learning (SSL) with deep generative models by addressing two core issues: mode covering/missing, and the conflict between classification and generation in directed models. It introduces joint-stochastic-approximation random fields (JRFs), deep undirected energy-based models trained via joint stochastic approximation (JSA). JRFs pair a target RF $p_\theta(x)$ with an auxiliary directed generator $q_\beta(h|x)$ and draw samples with Langevin-style MCMC. Using stochastic-approximation (SA) updates, JRFs jointly optimize the RF and the auxiliary model. For SSL, they model $(x,y)$ with a joint energy and add supervised terms and the regularizers $R_c$ and $R_s$. Empirically, JRFs achieve competitive SSL classification accuracy on MNIST, SVHN, and CIFAR-10 while also generating high-quality samples, avoiding the mode collapse seen in GAN-based approaches and matching the data distribution better than EBGMs. The work demonstrates, for the first time, that deep random-field models can effectively support SSL, suggesting a promising direction for undirected deep generative modeling in practical learning tasks.
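The TL;DR names two mechanisms: an SSL random field over $(x,y)$ defined by a joint energy, and Langevin-style sampling. The following is a minimal pure-Python sketch of both, built on several loud assumptions. The input is 1-D, there are $K=2$ classes, and a hand-written quadratic potential `u(x)[y]` stands in for the deep network. The parameterization $p(x,y) \propto \exp(u(x)[y])$ is assumed for illustration and is not taken from the text above. The auxiliary generator and the JSA parameter updates are omitted entirely. The sketch shows why classification is cheap: the normalizer cancels in $p(y|x)$. Drawing $x$, in contrast, needs MCMC, here unadjusted Langevin dynamics on the marginal $\log \sum_y \exp(u(x)[y])$.

```python
import math
import random

# Hypothetical toy model (assumption): scalar input x, K = 2 classes.
# u(x)[y] stands in for the network's class-y output; here it is a quadratic
# bump centred at MU[y], chosen only so the example is self-contained.
MU = (-3.0, 3.0)


def potentials(x):
    """Return the K unnormalised log-potentials u(x)[y]."""
    return [-0.5 * (x - m) ** 2 for m in MU]


def class_posterior(x):
    """p(y|x) = softmax(u(x)); the normaliser Z cancels, so no sampling is needed."""
    u = potentials(x)
    top = max(u)
    w = [math.exp(v - top) for v in u]
    s = sum(w)
    return [v / s for v in w]


def grad_log_marginal(x):
    """d/dx log sum_y exp(u(x)[y]) = sum_y p(y|x) * d u(x)[y] / dx."""
    post = class_posterior(x)
    return sum(p * (-(x - m)) for p, m in zip(post, MU))


def langevin_chain(x0, steps, eps, rng):
    """Unadjusted Langevin dynamics targeting p(x) proportional to sum_y exp(u(x)[y])."""
    x = x0
    for _ in range(steps):
        noise = rng.gauss(0.0, 1.0)
        x = x + 0.5 * eps * grad_log_marginal(x) + math.sqrt(eps) * noise
    return x


rng = random.Random(0)
samples = [langevin_chain(0.0, 200, 0.1, rng) for _ in range(1000)]

print(class_posterior(3.0))  # almost all mass on class 1
print(sum(s > 0 for s in samples) / len(samples))  # fraction of chains in the x > 0 mode
```

Because the toy marginal is an equal mixture of $\mathcal{N}(-3,1)$ and $\mathcal{N}(3,1)$, the chains should split roughly evenly between the two modes. In a real JRF the energy is a deep network and its normalizer is intractable, which is why sampling, rather than classification, is the costly part of training.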
Abstract
Our examination of deep generative models (DGMs) developed for semi-supervised learning (SSL), mainly GANs and VAEs, reveals two problems. First, mode missing and mode covering phenomena are observed in generation with GANs and VAEs. Second, there is an awkward conflict between good classification and good generation when directed generative models are used for SSL. To address these problems, we formally present joint-stochastic-approximation random fields (JRFs), a new family of algorithms for building deep undirected generative models, with application to SSL. Synthetic experiments show that JRFs balance mode covering and mode missing well and closely match the empirical data distribution. On the widely adopted SSL benchmarks MNIST, SVHN, and CIFAR-10, JRFs achieve classification results comparable to state-of-the-art methods while simultaneously generating good samples.
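A standard way to picture mode covering versus mode missing is to fit a single Gaussian $q$ to a two-mode mixture $p$ under the two directions of KL divergence. This is a generic illustration, not the paper's own analysis. Minimizing the inclusive divergence $\mathrm{KL}(p\,\|\,q)$, which is what maximum likelihood does, stretches $q$ over both modes and places mass between them (mode covering). Minimizing the exclusive divergence $\mathrm{KL}(q\,\|\,p)$ lets $q$ lock onto one mode and ignore the other (mode missing). The mixture, the grid ranges, and the Riemann-sum integration below are all assumptions chosen for brevity.

```python
import math

# Hypothetical target (assumption): equal-weight mixture of N(-3, 1) and N(3, 1).
MODES = (-3.0, 3.0)


def log_normal(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - 0.5 * ((x - mu) / sigma) ** 2


def log_p(x):
    """Log density of the two-mode mixture, computed with a stable log-sum-exp."""
    a, b = (log_normal(x, m, 1.0) for m in MODES)
    top = max(a, b)
    return top + math.log(0.5 * math.exp(a - top) + 0.5 * math.exp(b - top))


# Inclusive KL(p || q): for a Gaussian q the optimum matches the moments of p.
mu_fwd = sum(0.5 * m for m in MODES)
var_fwd = sum(0.5 * (1.0 + m ** 2) for m in MODES) - mu_fwd ** 2


def reverse_kl(mu, sigma, n=201):
    """Exclusive KL(q || p) for q = N(mu, sigma^2), via a Riemann sum over mu +/- 6 sigma."""
    lo, hi = mu - 6 * sigma, mu + 6 * sigma
    dx = (hi - lo) / (n - 1)
    cross = 0.0  # approximates E_q[log p]
    for i in range(n):
        x = lo + i * dx
        cross += math.exp(log_normal(x, mu, sigma)) * log_p(x) * dx
    neg_entropy = -0.5 * math.log(2 * math.pi * math.e * sigma ** 2)  # E_q[log q]
    return neg_entropy - cross


# Brute-force grid search over (mu, sigma) for the exclusive-KL fit.
grid = [(-5 + 0.25 * i, 0.5 + 0.1 * j) for i in range(41) for j in range(26)]
mu_rev, sig_rev = min(grid, key=lambda ms: reverse_kl(*ms))

print("inclusive KL fit (mean, variance):", mu_fwd, var_fwd)  # 0.0 10.0, spans both modes
print("exclusive KL fit (mean, std):", mu_rev, sig_rev)
```

The inclusive fit has mean 0 and variance 10, so it puts substantial mass in the low-density gap between the modes. The exclusive fit settles near one of $\pm 3$ with a standard deviation close to 1, so the other mode is missed.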
