Structured Partial Stochasticity in Bayesian Neural Networks
Tommy Rochussen
TL;DR
The paper tackles posterior multimodality in Bayesian neural networks caused by neuron permutation symmetries, which complicates approximate inference. It introduces structured partial stochasticity: a subset of weights is fixed deterministically so as to break intra-layer symmetries. Two schemes are proposed, Structured Light Pruning and Structured Heavy Pruning, together with a generalization that fixes weights to nonzero constants or MAP values; the fixing pattern is chosen so that $D_l^{(\mathrm{fc})} = \max(0, D_l - D_{l-1} - D_{l+1} + 2) \le 1$. The approach yields a simpler posterior and, empirically, improves mean-field variational inference on toy and real regression tasks, giving better predictive performance and better-calibrated uncertainty. Overall, removing symmetries via structured fixing offers a practical route to more accurate and scalable Bayesian inference in neural networks.
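The symmetry-breaking mechanism can be illustrated with a small NumPy sketch. It assumes a one-hidden-layer tanh MLP and a hypothetical fixing rule (hidden neuron i's weight from input i is held at a distinct nonzero constant). This is in the spirit of the nonzero-constant generalization, but it is not the paper's Structured Light or Heavy Pruning pattern. All 3! reorderings of the hidden neurons compute the same function, yet only the identity ordering still respects the fixed weights, so the permuted copies fall outside the constrained parameter space.

```python
import itertools

import numpy as np


def mlp(x, W1, b1, W2, b2):
    """One-hidden-layer tanh MLP; x has shape (N, D_in)."""
    return np.tanh(x @ W1.T + b1) @ W2.T + b2


def permute_hidden(perm, W1, b1, W2, b2):
    """Reorder the hidden neurons: rows of W1, entries of b1, columns of W2."""
    p = list(perm)
    return W1[p, :], b1[p], W2[:, p], b2


rng = np.random.default_rng(0)
D_in, H, D_out = 4, 3, 2
W1 = rng.normal(size=(H, D_in))
b1 = rng.normal(size=H)
W2 = rng.normal(size=(D_out, H))
b2 = rng.normal(size=D_out)
x = rng.normal(size=(5, D_in))

# Hypothetical fixing rule (illustration only): hidden neuron i's weight from
# input i is held at a distinct deterministic constant; all others stay free.
fixed_idx = [(i, i) for i in range(H)]
fixed_val = [-1.0, 0.5, 2.0]
for (r, c), v in zip(fixed_idx, fixed_val):
    W1[r, c] = v


def respects_fixing(W1):
    # Permuting is pure indexing (no arithmetic), so exact comparison is valid.
    return all(W1[r, c] == v for (r, c), v in zip(fixed_idx, fixed_val))


y = mlp(x, W1, b1, W2, b2)
perms = list(itertools.permutations(range(H)))

# Every permutation gives the same network function...
n_same_function = sum(
    bool(np.allclose(mlp(x, *permute_hidden(p, W1, b1, W2, b2)), y)) for p in perms
)
# ...but only the identity permutation keeps the fixed weights in place.
n_respecting = sum(respects_fixing(permute_hidden(p, W1, b1, W2, b2)[0]) for p in perms)
print(len(perms), n_same_function, n_respecting)  # 6 6 1
```

The sketch keeps H ≤ D_in so that every hidden neuron receives a fixed weight. If two neurons were left with no fixed weights at all, swapping them would still respect the constraints, and that symmetry would survive.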
Abstract
Bayesian neural network posterior distributions have many modes that correspond to the same network function. The abundance of such modes can hinder approximate inference methods. Recent work has demonstrated the benefits of partial stochasticity for approximate inference in Bayesian neural networks: inference can be less costly, and performance can sometimes be improved. I propose a structured way to select the deterministic subset of weights that removes neuron permutation symmetries, and therefore the corresponding redundant posterior modes. With the posterior distribution drastically simplified, the performance of existing approximate inference schemes is found to be greatly improved.
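To give a sense of scale for the redundant modes: every weight setting has one functionally identical copy per reordering of the neurons within each hidden layer, so a network with hidden widths $D_1, \dots, D_L$ has $\prod_l D_l!$ permutation-equivalent copies (distinct for generic weights). The minimal Python sketch below computes that count. The widths are illustrative assumptions, and sign-flip symmetries of odd activations such as tanh, which would multiply the count further, are not included.

```python
import math


def num_permutation_copies(hidden_widths):
    """Count weight settings related by hidden-neuron permutations: prod_l D_l!.

    Only permutation symmetry is counted; sign-flip symmetries are ignored.
    """
    return math.prod(math.factorial(d) for d in hidden_widths)


print(num_permutation_copies([3]))     # 6
print(num_permutation_copies([3, 4]))  # 144
print(f"{num_permutation_copies([50, 50]):.3e}")
```

Even two hidden layers of 50 neurons give $(50!)^2$, about $9.25 \times 10^{128}$, equivalent copies of each mode.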
