Structured Partial Stochasticity in Bayesian Neural Networks

Tommy Rochussen

TL;DR

The paper tackles posterior multimodality in Bayesian neural networks caused by neuron permutation symmetries, which complicates approximate inference. It introduces structured partial stochasticity by deterministically fixing a subset of weights to break intra-layer symmetries, with two schemes (Structured Light Pruning and Structured Heavy Pruning) and a generalization to fixing nonzero constants or MAP values, ensuring $D_{\text{l}}^{(fc)} = \max(0, D_{\text{l}} - D_{\text{l}-1} - D_{\text{l}+1} + 2) \le 1$. The approach yields a simpler posterior and, empirically, improves mean-field variational inference on toy and real regression tasks, with better predictive performance and calibrated uncertainty. Overall, symmetry removal via structured fixing provides a practical route to more accurate and scalable Bayesian inference in neural networks.

Abstract

Bayesian neural network posterior distributions have a great number of modes that correspond to the same network function. The abundance of such modes can make it difficult for approximate inference methods to do their job. Recent work has demonstrated the benefits of partial stochasticity for approximate inference in Bayesian neural networks; inference can be less costly and performance can sometimes be improved. I propose a structured way to select the deterministic subset of weights that removes neuron permutation symmetries, and therefore the corresponding redundant posterior modes. With a drastically simplified posterior distribution, the performance of existing approximate inference schemes is found to be greatly improved.
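
The permutation symmetry behind these redundant modes can be checked numerically. The sketch below is illustrative only and is not taken from the paper: the network size, the `tanh` activation, and the names `mlp`, `perm`, and `n_equivalent` are assumptions. It builds a one-hidden-layer network, reorders its hidden neurons, and confirms that the reordered weights define the same function.

```python
import math
import random


def mlp(x, W1, b1, w2, b2):
    """One-hidden-layer tanh network with scalar input and scalar output."""
    hidden = [math.tanh(w * x + b) for w, b in zip(W1, b1)]
    return sum(v * h for v, h in zip(w2, hidden)) + b2


rng = random.Random(0)
H = 4  # number of hidden neurons (arbitrary choice for illustration)

# Random weights standing in for one posterior sample.
W1 = [rng.gauss(0.0, 1.0) for _ in range(H)]
b1 = [rng.gauss(0.0, 1.0) for _ in range(H)]
w2 = [rng.gauss(0.0, 1.0) for _ in range(H)]
b2 = rng.gauss(0.0, 1.0)

# Reorder the hidden neurons: the same permutation is applied to the
# incoming weights, the biases, and the outgoing weights.
perm = [2, 0, 3, 1]
W1_p = [W1[i] for i in perm]
b1_p = [b1[i] for i in perm]
w2_p = [w2[i] for i in perm]

# The two weight settings differ, yet they give the same outputs
# (up to floating-point rounding from the changed summation order).
xs = [-1.0, 0.0, 0.5, 2.0]
max_diff = max(
    abs(mlp(x, W1, b1, w2, b2) - mlp(x, W1_p, b1_p, w2_p, b2)) for x in xs
)

# A hidden layer with H neurons admits H! such reorderings.
n_equivalent = math.factorial(H)

print(max_diff, n_equivalent)
```

Each reordering is a distinct point in weight space with the same likelihood, so under a prior that treats the neurons of a layer exchangeably it has the same posterior density as well. Fixing a structured subset of weights, as the abstract proposes, is aimed at making these reorderings no longer equivalent.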

Paper Structure (18 sections, 7 equations, 3 figures, 3 tables)

Figures (3)

  • Figure 1: Connections that are fixed under structured partial stochasticity. The first and second subfigures correspond to the light scheme, while the third subfigure corresponds to the heavy scheme. Connections with weights that are free are omitted, except for a connection (dashed line) from neuron $h_5$ in the second and third subfigures, whose weight takes a near-zero value.
  • Figure 2: One-dimensional regression. Blue circles indicate datapoints; black lines and shaded areas represent predictive means and 95% confidence regions, respectively.
  • Figure 3: Fixed connections under the respective light and heavy schemes of structured partial stochasticity.