A Principled Bayesian Framework for Training Binary and Spiking Neural Networks

James A. Walker, Moein Khajehnejad, Adeel Razi

TL;DR

This work tackles the difficulty of training binary and spiking neural networks with non-differentiable activations by introducing a principled Bayesian framework built on importance-weighted straight-through (IW-ST(p)) estimators, which unify straight-through and continuous-relaxation approaches. Through variational inference and local reparameterisation, the authors instantiate Spiking Bayesian Neural Networks (SBNNs) that incorporate noise and KL regularisation, enabling end-to-end training without normalisation layers or hand-tuned surrogates. They also develop the Analytical Gumbel–Rao (AGR) estimator and use the broader IW-ST(p) family to balance bias and variance in deep networks. Empirically, the method achieves state-of-the-art or competitive results on CIFAR-10, DVS Gesture, and SHD, highlighting the practical value of principled Bayesian noise for discrete and temporal neural models.

Abstract

We propose a Bayesian framework for training binary and spiking neural networks that achieves state-of-the-art performance without normalisation layers. Unlike commonly used surrogate gradient methods -- often heuristic and sensitive to hyperparameter choices -- our approach is grounded in a probabilistic model of noisy binary networks, enabling fully end-to-end gradient-based optimisation. We introduce importance-weighted straight-through (IW-ST) estimators, a unified class generalising straight-through and relaxation-based estimators. We characterise the bias-variance trade-off in this family and derive a bias-minimising objective implemented via an auxiliary loss. Building on this, we introduce Spiking Bayesian Neural Networks (SBNNs), a variational inference framework that uses posterior noise to train Binary and Spiking Neural Networks with IW-ST. This Bayesian approach minimises gradient bias, regularises parameters, and introduces dropout-like noise. By linking low-bias conditions, vanishing gradients, and the KL term, we enable training of deep residual networks without normalisation. Experiments on CIFAR-10, DVS Gesture, and SHD show our method matches or exceeds existing approaches without normalisation or hand-tuned gradients.

Paper Structure

This paper contains 6 sections, 10 theorems, 93 equations, 6 figures, 4 tables.

Key Result

Proposition 1

Under the identifications $S'(h_i^{(l)}) \approx F'(h_i^{(l)})$ and $\mathbb{E}\left[\mathcal{L}(o_i^{(l)} = 1) - \mathcal{L}(o_i^{(l)} = 0)\right] = \frac{d\mathcal{L}}{do_i^{(l)}}$, the recursive form of the straight-through (ST) estimator gradient matches the recursive surrogate gradient (SG) expression. Here $S(\cdot)$ is the surrogate function and $F(\cdot)$ is the spiking probability.
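
The second identification in Proposition 1 can be checked numerically on a single stochastic binary neuron. The sketch below is an illustration under stated assumptions, not the paper's implementation: the spiking probability $F$ is assumed to be a logistic sigmoid, the surrogate derivative is taken to be exactly $F'$, and the two toy losses (`linear_loss`, `quadratic_loss`) are hypothetical. It compares the exact gradient of the expected loss, $F'(h)\,[\mathcal{L}(1) - \mathcal{L}(0)]$, with the expected ST gradient, $F'(h)\,\mathbb{E}[d\mathcal{L}/do]$.

```python
import math


def sigmoid(h):
    """Assumed spiking probability F(h) = P(o = 1 | h)."""
    return 1.0 / (1.0 + math.exp(-h))


def exact_grad(loss, h):
    """Exact d/dh of E[loss(o)] for o ~ Bernoulli(F(h)).

    E[loss] = F(h) * loss(1) + (1 - F(h)) * loss(0), so the
    derivative is F'(h) * (loss(1) - loss(0)).
    """
    p = sigmoid(h)
    return p * (1.0 - p) * (loss(1.0) - loss(0.0))


def expected_st_grad(dloss_do, h):
    """Expected straight-through gradient F'(h) * E[dL/do].

    The expectation over o ~ Bernoulli(F(h)) is taken in closed form,
    so the result is deterministic (no sampling noise).
    """
    p = sigmoid(h)
    expected_dl = p * dloss_do(1.0) + (1.0 - p) * dloss_do(0.0)
    return p * (1.0 - p) * expected_dl


# Hypothetical toy losses and their derivatives with respect to o.
def linear_loss(o):
    return 3.0 * o + 1.0


def linear_dloss(o):
    return 3.0


def quadratic_loss(o):
    return (o - 0.25) ** 2


def quadratic_dloss(o):
    return 2.0 * (o - 0.25)


h = 1.0
gap_linear = abs(exact_grad(linear_loss, h) - expected_st_grad(linear_dloss, h))
gap_quadratic = abs(exact_grad(quadratic_loss, h) - expected_st_grad(quadratic_dloss, h))
print("linear loss gap:   ", gap_linear)
print("quadratic loss gap:", gap_quadratic)
```

For the linear loss the finite difference $\mathcal{L}(1) - \mathcal{L}(0)$ equals $d\mathcal{L}/do$, so the two gradients coincide. For the quadratic loss they differ, which is one simple way to see where ST-style estimators pick up bias.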

Figures (6)

  • Figure 1: Training loss and test accuracy across epochs for a) CIFAR-10, b) DVS-Gesture, and c) SHD datasets
  • Figure 2: Training loss and test accuracy for IW-ST($p$) estimators on CIFAR-10
  • Figure 3: Training loss and test accuracy for AGR estimators on CIFAR-10
  • Figure 4: Gradient norms per layer after 30 epochs
  • Figure 5: Attenuation factor over the first 30 epochs
  • ...and 1 more figure

Theorems & Definitions (27)

  • Proposition 1: Equivalence to surrogate gradient
  • Definition 1: Analytical Gumbel-Rao estimator (AGR)
  • Remark 1: Gradient damping
  • Remark 2: Bias-variance trade-off
  • Definition 2: Importance-weighted straight-through estimator (IW-ST($p$))
  • Remark 3
  • Theorem 1
  • Proof: Proof sketch
  • Corollary 1
  • Proof
  • ...and 17 more