
Gradient-free variational learning with conditional mixture networks

Conor Heins, Hao Wu, Dimitrije Markovic, Alexander Tschantz, Jeff Beck, Christopher Buckley

TL;DR

The paper addresses the challenge of obtaining calibrated predictions and uncertainty quantification in Bayesian neural-network-like models without prohibitive computational cost. It introduces CAVI-CMN, a gradient-free variational learning algorithm for training two-layer conditional mixture networks. The algorithm exploits conditional conjugacy and Pólya-Gamma augmentation to obtain Gaussian posteriors with analytic (closed-form) updates. Empirically, CAVI-CMN matches or exceeds the predictive performance of gradient-based MLE while delivering full posterior distributions and calibrated predictions, with runtimes competitive with BBVI and NUTS and favorable scaling to larger models. The approach offers a practical, online-friendly Bayesian alternative for fast probabilistic networks and points to extensions to deeper architectures and minibatch/streaming learning.
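The analytic updates mentioned above rest on a standard identity: with Pólya-Gamma augmentation, a logistic likelihood becomes Gaussian in the weights once conditioned on an auxiliary variable ω ~ PG(1, c). The sketch below is a minimal illustration under simplifying assumptions: it applies mean-field CAVI to plain binary Bayesian logistic regression with a zero-mean isotropic Gaussian prior, not to the paper's softmax gating network or full CMN. The function name `pg_cavi_logistic` and its `prior_var`/`n_iters` parameters are hypothetical. Each sweep alternates two closed-form updates, q(ω_n) = PG(1, c_n) and a Gaussian q(w), with no gradient computation.

```python
import numpy as np


def pg_cavi_logistic(X, y, prior_var=10.0, n_iters=50):
    """Mean-field CAVI for binary Bayesian logistic regression with Polya-Gamma augmentation.

    Assumed model: y_n ~ Bernoulli(sigmoid(x_n @ w)), w ~ N(0, prior_var * I).
    Variational family: q(w) = N(m, S) and q(omega_n) = PG(1, c_n).
    """
    d = X.shape[1]
    prior_prec = np.eye(d) / prior_var
    kappa = y - 0.5  # kappa_n = y_n - 1/2 from the Polya-Gamma identity
    m, S = np.zeros(d), np.eye(d) * prior_var
    for _ in range(n_iters):
        # q(omega_n) = PG(1, c_n) with c_n = sqrt(E_q[(x_n @ w)^2])
        second_moment = S + np.outer(m, m)
        c = np.sqrt(np.maximum(np.einsum("nd,de,ne->n", X, second_moment, X), 0.0))
        # E[omega_n] = tanh(c_n / 2) / (2 c_n), which tends to 1/4 as c_n -> 0
        c_safe = np.maximum(c, 1e-8)
        e_omega = np.where(c < 1e-8, 0.25, np.tanh(c_safe / 2) / (2 * c_safe))
        # q(w) is Gaussian in closed form: no gradients or sampling needed
        prec = prior_prec + X.T @ (e_omega[:, None] * X)
        S = np.linalg.inv(prec)
        m = S @ (X.T @ kappa)
    return m, S


# Usage on synthetic data (intercept column plus two features)
rng = np.random.default_rng(0)
X = np.column_stack([np.ones(200), rng.normal(size=(200, 2))])
w_true = np.array([0.5, 2.0, -1.5])
y = (rng.random(200) < 1 / (1 + np.exp(-X @ w_true))).astype(float)
m, S = pg_cavi_logistic(X, y)
print("posterior mean:", m)
```

Because both updates are closed-form, each sweep costs one d×d matrix inversion plus O(n d²) work. The multiclass (softmax) case used for gating needs further augmentation that this binary sketch does not show.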

Abstract

Balancing computational efficiency with robust predictive performance is crucial in supervised learning, especially for critical applications. Standard deep learning models, while accurate and scalable, often lack probabilistic features like calibrated predictions and uncertainty quantification. Bayesian methods address these issues but can be computationally expensive as model and data complexity increase. Previous work shows that fast variational methods can reduce the compute requirements of Bayesian methods by eliminating the need for gradient computation or sampling, but are often limited to simple models. We introduce CAVI-CMN, a fast, gradient-free variational method for training conditional mixture networks (CMNs), a probabilistic variant of the mixture-of-experts (MoE) model. CMNs are composed of linear experts and a softmax gating network. By exploiting conditional conjugacy and Pólya-Gamma augmentation, we furnish Gaussian likelihoods for the weights of both the linear layers and the gating network. This enables efficient variational updates using coordinate ascent variational inference (CAVI), avoiding traditional gradient-based optimization. We validate this approach by training two-layer CMNs on standard classification benchmarks from the UCI repository. CAVI-CMN achieves competitive and often superior predictive accuracy compared to maximum likelihood estimation (MLE) with backpropagation, while maintaining competitive runtime and full posterior distributions over all model parameters. Moreover, as input size or the number of experts increases, computation time scales competitively with MLE and other gradient-based solutions like black-box variational inference (BBVI), making CAVI-CMN a promising tool for deep, fast, and gradient-free Bayesian networks.
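To make the architecture concrete, the sketch below samples from a two-layer CMN using the variables of Figure 1: a softmax gate picks an expert z1 given the input x0, the chosen linear expert produces the continuous latent x1, and an output softmax layer maps x1 to class probabilities for y. It is a hypothetical illustration with fixed (point-estimate) weights; the absence of bias terms, the isotropic noise level `noise_std`, and all function and parameter names are assumptions. CAVI-CMN itself maintains posterior distributions over these weights rather than fixing them.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def cmn_predictive(x0, gate_W, expert_A, out_W, n_samples=2000, noise_std=0.1, rng=None):
    """Monte Carlo estimate of p(y | x0) for a hypothetical two-layer CMN with fixed weights.

    gate_W:   (K, d)     gating weights: p(z1 = k | x0) = softmax(gate_W @ x0)[k]
    expert_A: (K, h, d)  one linear expert per mixture component
    out_W:    (C, h)     output weights: p(y = c | x1) = softmax(out_W @ x1)[c]
    """
    rng = np.random.default_rng() if rng is None else rng
    gate_probs = softmax(gate_W @ x0)                          # (K,)
    z1 = rng.choice(len(gate_probs), size=n_samples, p=gate_probs)
    means = expert_A[z1] @ x0                                  # (n_samples, h)
    x1 = means + noise_std * rng.standard_normal(means.shape)  # latent x1 given z1, x0
    class_probs = softmax(x1 @ out_W.T, axis=-1)               # (n_samples, C)
    return class_probs.mean(axis=0)                            # marginalize over z1 and x1


rng = np.random.default_rng(0)
d, K, h, C = 4, 3, 2, 3  # input dim, number of experts, latent dim, number of classes
gate_W = rng.normal(size=(K, d))
expert_A = rng.normal(size=(K, h, d))
out_W = rng.normal(size=(C, h))
p_y = cmn_predictive(rng.normal(size=d), gate_W, expert_A, out_W, rng=rng)
print(p_y)
```

The discrete expert choice z1 is what makes the layer a mixture of experts; marginalizing it (here by sampling) gives a nonlinear input-output map even though each expert is linear.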

Paper Structure

This paper contains 28 sections, 32 equations, 9 figures, and 1 table.

Figures (9)

  • Figure 1: A Bayesian network representation of the two-layer conditional mixture network, with input-output pairs $\pmb{x}^n_0, y^n$ and latent variables $\pmb{x}^n_1, z^n_1$. Observations are shaded nodes, while latents and parameters are transparent. Prior hyperparameters are shown without boundaries.
  • Figure 2: Performance and runtime results of the different inference algorithms on the 'Pinwheel' dataset. Circles show the mean of each metric across runs and vertical lines show its standard deviation. The top row shows performance metrics as a function of training set size: test accuracy (top left), log predictive density (top center), and expected calibration error (top right). The bottom row shows runtime metrics as a function of training set size: the number of iterations required to reach convergence (bottom left), and the total runtime, estimated as the product of the number of iterations to convergence and the average cost (in seconds) of one iteration (bottom right). The number of iterations to convergence was calculated as the number of gradient steps (or M steps, for CAVI) taken before the ELBO (or negative log likelihood, for MLE) reached 95% of its maximum value (see the appendix on convergence details for how these metrics were computed).
  • Figure 3: Performance and runtime results of the different models on the 'Waveform Domains' dataset. The waveforms dataset consists of synthetic data generated to classify three different waveform patterns. Each instance is described by 21 continuous attributes. See https://archive.ics.uci.edu/dataset/107/waveform+database+generator+version+1 for more information about the dataset. Subplots are as described in the Figure 2 caption.
  • Figure 4: Relative scaling of fitting time in seconds for Maximum Likelihood, BBVI, and CAVI, as a function of the number of parameters. The number of parameters itself was manipulated in three illustrative ways: changing the input dimension $d$, changing the number of linear experts $K$ in the conditional mixture layer, and changing the dimensionality of the continuous latent variable $h$.
  • Figure 5: Performance and runtime results of the different models on the 'Vehicle Silhouettes' dataset. Subplots are as described in the Figure 2 caption.
  • ...and 4 more figures
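The evaluation metrics described in the Figure 2 caption can be sketched as follows. This is a minimal illustration, not the paper's evaluation code: the appendix with convergence details is not reproduced here, so several choices are assumptions. These include reading "95% of its maximum value" as 95% of the improvement from the initial objective (which stays well-defined when the ELBO is negative), using 10 equal-width confidence bins for the expected calibration error, and all function names.

```python
import numpy as np


def iters_to_convergence(trace, frac=0.95):
    """Steps until a maximized objective (ELBO, or -NLL for MLE) first covers `frac`
    of its total improvement. trace[i] is the objective after i updates."""
    trace = np.asarray(trace, dtype=float)
    threshold = trace[0] + frac * (trace.max() - trace[0])
    return int(np.argmax(trace >= threshold))  # first index at or above threshold


def estimated_runtime(n_iters, seconds_per_iter):
    # Total runtime = iterations to convergence x average cost of one iteration
    return n_iters * seconds_per_iter


def log_predictive_density(probs, labels):
    # Mean log probability assigned to the true test labels
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    return float(np.mean(np.log(probs[np.arange(len(labels)), labels])))


def expected_calibration_error(probs, labels, n_bins=10):
    """ECE with equal-width bins over the top-class confidence."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    bin_ids = np.minimum((conf * n_bins).astype(int), n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if in_bin.any():
            # weight |accuracy - confidence| in the bin by the fraction of samples in it
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return float(ece)


trace = [-100.0, -40.0, -20.0, -12.0, -10.5, -10.0]
probs = [[0.95, 0.05], [0.85, 0.15], [0.25, 0.75], [0.65, 0.35]]
labels = [0, 1, 1, 0]
n_iters = iters_to_convergence(trace)  # threshold is about -14.5, first reached at index 3
print(n_iters, estimated_runtime(n_iters, 0.02))
print(log_predictive_density(probs, labels), expected_calibration_error(probs, labels))
```

Measuring convergence relative to the starting value keeps the 95% criterion comparable across methods whose objectives live on different scales (ELBO versus log likelihood).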