A Principled Bayesian Framework for Training Binary and Spiking Neural Networks
James A. Walker, Moein Khajehnejad, Adeel Razi
TL;DR
This work tackles the difficulty of training binary and spiking neural networks with non-differentiable activations by introducing a principled Bayesian framework built on importance-weighted straight-through estimators, IW-ST(p), which unify straight-through and continuous-relaxation approaches. Through variational inference and local reparameterisation, the authors instantiate Spiking Bayesian Neural Networks (SBNNs) that incorporate noise and KL regularisation, enabling end-to-end training without normalisation layers or hand-tuned surrogates. They also develop the Analytical Gumbel–Rao (AGR) estimator and a broader IW-ST(p) family to balance bias and variance in deep networks. Empirically, the method achieves state-of-the-art or competitive results on CIFAR-10, DVS Gesture, and SHD, highlighting the practical value of principled Bayesian noise for discrete and temporal neural models.
Abstract
We propose a Bayesian framework for training binary and spiking neural networks that achieves state-of-the-art performance without normalisation layers. Unlike commonly used surrogate gradient methods, which are often heuristic and sensitive to hyperparameter choices, our approach is grounded in a probabilistic model of noisy binary networks, enabling fully end-to-end gradient-based optimisation. We introduce importance-weighted straight-through (IW-ST) estimators, a unified class generalising straight-through and relaxation-based estimators. We characterise the bias-variance trade-off in this family and derive a bias-minimising objective implemented via an auxiliary loss. Building on this, we introduce Spiking Bayesian Neural Networks (SBNNs), a variational inference framework that uses posterior noise to train binary and spiking neural networks with IW-ST. This Bayesian approach minimises gradient bias, regularises parameters, and introduces dropout-like noise. By linking low-bias conditions, vanishing gradients, and the KL term, we enable training of deep residual networks without normalisation. Experiments on CIFAR-10, DVS Gesture, and SHD show our method matches or exceeds existing approaches without normalisation or hand-tuned gradients.
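For readers unfamiliar with the straight-through idea that IW-ST generalises, the following is a minimal sketch of a plain straight-through estimator for a noisy binary unit. It is not the IW-ST or AGR estimator from this work. It assumes logistic noise on the pre-activation, so that the unit fires with probability sigmoid(pre-activation), and all function names are hypothetical.

```python
import math
import random


def firing_probability(pre_activation):
    """Probability that a unit with logistic pre-activation noise fires."""
    return 1.0 / (1.0 + math.exp(-pre_activation))


def noisy_binary_forward(pre_activations, rng):
    """Forward pass: sample hard 0/1 outputs.

    Drawing a Bernoulli sample with probability sigmoid(a) is equivalent to
    thresholding a + logistic noise at zero (assumed noise model).
    """
    return [1.0 if rng.random() < firing_probability(a) else 0.0
            for a in pre_activations]


def straight_through_backward(pre_activations, upstream_grads):
    """Backward pass: ignore the hard threshold and use the derivative of
    the firing probability, p * (1 - p), as a surrogate gradient."""
    grads = []
    for a, g in zip(pre_activations, upstream_grads):
        p = firing_probability(a)
        grads.append(g * p * (1.0 - p))
    return grads


rng = random.Random(0)
pre_activations = [-2.0, 0.0, 2.0]
spikes = noisy_binary_forward(pre_activations, rng)
grads = straight_through_backward(pre_activations, [1.0, 1.0, 1.0])
```

For a single unit this surrogate equals the gradient of the expected output. In a deep network the sampled outputs feed later layers, so the estimate becomes biased, which is the bias-variance trade-off the abstract refers to.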
