Spike-and-slab shrinkage priors for structurally sparse Bayesian neural networks

Sanket Jantre, Shrijita Bhattacharya, Tapabrata Maiti

TL;DR

This work tackles the challenge of achieving structurally sparse Bayesian neural networks with principled node pruning. It introduces two spike-and-slab priors, SS-GL and SS-GHS, and a variational inference framework with a continuous relaxation of the Bernoulli inclusion variables that prunes neurons layer by layer while maintaining predictive performance. The authors derive variational posterior contraction rates that depend on the network topology, layer-wise node cardinalities, and bounds on the network weights, and validate the approach with experiments using an MLP on MNIST, LeNet-5-Caffe, and ResNet on CIFAR-10, showing improvements in accuracy and compression and reductions in FLOPs. Overall, the paper provides both theoretical guarantees and practical methods for efficient, structured sparsity in Bayesian neural networks.
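To make the two prior families concrete, here is a minimal NumPy sketch of drawing node-wise weight groups from a spike-and-slab prior with a group-lasso slab and with a group-horseshoe slab. It treats the incoming weights of each node as one group and assumes the standard scale-mixture forms (Bayesian group lasso: $\tau_j^2 \sim \mathrm{Gamma}((m+1)/2,\ \text{rate}=\lambda^2/2)$ for a group of size $m$; horseshoe: half-Cauchy local and global scales). The function names, the inclusion probability `incl_prob`, and all hyperparameter values are hypothetical, and the paper's exact hierarchy may differ; in particular, the $c_{\rm reg}$ regularization constant from the SS-GHS experiments is not modeled here.

```python
import numpy as np


def sample_ss_group_lasso(n_nodes, group_size, incl_prob, lam, sigma0=1.0, rng=None):
    """Spike-and-slab prior with a group-lasso slab (assumed scale-mixture form).

    z_j     ~ Bernoulli(incl_prob)                      # node j kept (1) or pruned (0)
    tau_j^2 ~ Gamma(shape=(group_size + 1) / 2, rate=lam**2 / 2)
    w_j     ~ z_j * N(0, sigma0**2 * tau_j^2 * I) + (1 - z_j) * delta_0
    """
    rng = np.random.default_rng() if rng is None else rng
    z = rng.binomial(1, incl_prob, size=n_nodes)
    # NumPy parameterizes the gamma by scale, so scale = 1 / rate = 2 / lam**2.
    tau2 = rng.gamma(shape=(group_size + 1) / 2, scale=2.0 / lam**2, size=n_nodes)
    slab = rng.normal(size=(n_nodes, group_size)) * (sigma0 * np.sqrt(tau2))[:, None]
    # The spike is a point mass at zero: the whole group is zeroed when z_j = 0.
    return z[:, None] * slab, z


def sample_ss_group_horseshoe(n_nodes, group_size, incl_prob, tau0=1.0, rng=None):
    """Spike-and-slab prior with a plain (unregularized) group-horseshoe slab.

    z_j      ~ Bernoulli(incl_prob)
    lambda_j ~ Half-Cauchy(0, 1)                        # local, per-node scale
    tau      ~ Half-Cauchy(0, tau0)                     # global scale shared by the layer
    w_j      ~ z_j * N(0, tau**2 * lambda_j**2 * I) + (1 - z_j) * delta_0
    """
    rng = np.random.default_rng() if rng is None else rng
    z = rng.binomial(1, incl_prob, size=n_nodes)
    local = np.abs(rng.standard_cauchy(size=n_nodes))
    global_scale = tau0 * np.abs(rng.standard_cauchy())
    slab = rng.normal(size=(n_nodes, group_size)) * (global_scale * local)[:, None]
    return z[:, None] * slab, z


# Illustrative sizes only: 400 nodes, each with 784 incoming weights.
rng = np.random.default_rng(0)
w_gl, z_gl = sample_ss_group_lasso(n_nodes=400, group_size=784, incl_prob=0.5, lam=1.0, rng=rng)
w_hs, z_hs = sample_ss_group_horseshoe(n_nodes=400, group_size=784, incl_prob=0.5, rng=rng)
print(w_gl.shape, z_gl.mean(), z_hs.mean())
```

The group structure is what yields node sparsity: a single Bernoulli indicator switches off all incoming weights of a node at once, so a pruned node can be removed from the architecture rather than merely having scattered zero weights.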

Abstract

Network complexity and computational efficiency have become increasingly significant aspects of deep learning. Sparse deep learning addresses these challenges by recovering a sparse representation of the underlying target function by reducing heavily over-parameterized deep neural networks. Specifically, deep neural architectures compressed via structured sparsity (e.g., node sparsity) provide low-latency inference, higher data throughput, and reduced energy consumption. In this paper, we explore two well-established shrinkage techniques, Lasso and Horseshoe, for model compression in Bayesian neural networks. To this end, we propose structurally sparse Bayesian neural networks which systematically prune excessive nodes with (i) Spike-and-Slab Group Lasso (SS-GL) and (ii) Spike-and-Slab Group Horseshoe (SS-GHS) priors, and develop computationally tractable variational inference including a continuous relaxation of Bernoulli variables. We establish the contraction rates of the variational posterior of our proposed models as a function of the network topology, layer-wise node cardinalities, and bounds on the network weights. We empirically demonstrate the competitive performance of our models compared to the baseline models in prediction accuracy, model compression, and inference latency.
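The continuous relaxation mentioned in the abstract replaces each discrete node indicator $z_j \in \{0, 1\}$ with a differentiable surrogate so the variational objective can be optimized with reparameterized gradients. The sketch below uses the binary Concrete (relaxed Bernoulli) relaxation, one standard choice; this summary does not say which relaxation the paper uses. The temperature, the gating of a ReLU layer's pre-activations, the 0.5 pruning threshold, and all function names are illustrative assumptions.

```python
import numpy as np


def stable_sigmoid(x):
    # sigmoid(x) = (1 + tanh(x / 2)) / 2 avoids overflow in exp for large |x|.
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def relaxed_bernoulli(logit_p, temperature, rng):
    """Binary Concrete sample: sigmoid((logit(p) + L) / temperature), L ~ Logistic(0, 1).

    As temperature -> 0 the samples concentrate on {0, 1} and P(sample > 0.5) = p.
    """
    logit_p = np.asarray(logit_p, dtype=float)
    u = rng.uniform(1e-12, 1.0 - 1e-12, size=logit_p.shape)
    logistic_noise = np.log(u) - np.log1p(-u)
    return stable_sigmoid((logit_p + logistic_noise) / temperature)


def gated_relu_layer(x, W, b, gate):
    # Scaling node j's incoming weights and bias by gate[j] equals scaling its
    # pre-activation; a gate of 0 switches the node off entirely.
    return np.maximum((x @ W + b) * gate, 0.0)


def prune_layer(W, b, incl_prob, threshold=0.5):
    # Keep nodes whose variational inclusion probability exceeds the threshold.
    # The next layer's weight rows must be sliced with the same boolean mask.
    keep = np.asarray(incl_prob) > threshold
    return W[:, keep], b[keep], keep


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 784))            # hypothetical mini-batch of flattened inputs
W = 0.01 * rng.normal(size=(784, 400))    # hypothetical layer with 400 nodes
b = np.zeros(400)
logits = rng.normal(size=400)             # hypothetical variational inclusion logits
gate = relaxed_bernoulli(logits, temperature=0.5, rng=rng)
h = gated_relu_layer(x, W, b, gate)
W_kept, b_kept, keep = prune_layer(W, b, stable_sigmoid(logits))
```

During training the relaxed gate keeps the objective differentiable in the inclusion logits; after training, thresholding the inclusion probabilities gives a hard keep/prune decision per node.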

Paper Structure (40 sections, 11 theorems, 69 equations, 3 figures, 2 tables, 1 algorithm)

Key Result

Theorem 1

Let assumptions A.1-A.4 hold. Let $\epsilon_n=\sqrt{(\sum_{l=0}^L r_l+\xi)\sum_{l=0}^L u_l}$, with $r_l$ and $u_l$ as defined in (e:gl-r) and $\xi$ as defined in (e:xi-def). For a sequence $M_n \to \infty$ with $M_n \epsilon_n \to 0$, the variational posterior $\widetilde{\Pi}^*$ satisfies,
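The rate in Theorem 1 is a simple function of the layer-wise quantities $r_l$, $u_l$ ($l = 0, \dots, L$) and $\xi$. Their exact definitions are in the referenced equations and are not reproduced here, so this sketch takes them as inputs; the function name and the numbers in the usage line are hypothetical placeholders, not values from the paper.

```python
import math
from typing import Sequence


def contraction_rate(r: Sequence[float], u: Sequence[float], xi: float) -> float:
    """epsilon_n = sqrt((sum_{l=0}^{L} r_l + xi) * sum_{l=0}^{L} u_l).

    r and u hold one entry per layer index l = 0, ..., L.
    """
    if len(r) != len(u):
        raise ValueError("r and u must both have one entry per layer l = 0, ..., L")
    if xi < 0 or any(v < 0 for v in (*r, *u)):
        raise ValueError("r_l, u_l and xi are expected to be non-negative")
    return math.sqrt((sum(r) + xi) * sum(u))


# Hypothetical inputs for a network with L = 2 (layer indices l = 0, 1, 2).
eps_n = contraction_rate(r=[0.02, 0.01, 0.005], u=[0.3, 0.2, 0.1], xi=0.001)
```

Because $\epsilon_n$ is the square root of a product of two sums, it tends to zero only if $(\sum_l r_l + \xi)\sum_l u_l \to 0$; the condition $M_n \epsilon_n \to 0$ then leaves room for a slowly diverging $M_n$.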

Figures (3)

  • Figure 1: MNIST Experiment. We demonstrate the performance of SS-GL and SS-GHS models in a 2-layer perceptron network for classifying the MNIST dataset. We include the closely related SS-IG model (Jantre et al., 2023) for comparison. (a) Test data prediction accuracy. (b) and (c) Proportion of active nodes (node sparsity) in layer-1 and layer-2 of the network, respectively. SS-GHS achieves the best predictive performance and the most compact network, as seen in the proportion of active nodes in layer-1 and layer-2.
  • Figure 2: MLP-MNIST experiment. We compare the performance of SS-GHS, SS-GL, and SS-IG models on (a) classification accuracy, (b) model compression ratio, and (c) FLOPs ratio. The SS-GHS model achieves the highest accuracy, best compression, and lowest FLOPs, making it the most efficient in performance and resource use.
  • Figure 3: SS-GHS $c_{\rm reg}$ choice experiment. Performance of SS-GHS with regularization constant $c_{\rm reg}=1$ and $c_{\rm reg}=k_l+1=401$. (a) Classification accuracy on the test data. (b) and (c) Proportion of active nodes (node sparsity) in layer-1 and layer-2 of the network, respectively. Both $c_{\rm reg}$ choices give similar classification accuracies, and $c_{\rm reg}=1$ yields higher node sparsity in layer-1.
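The captions report node sparsity (proportion of active nodes), a model compression ratio, and a FLOPs ratio. A minimal sketch of how such quantities can be computed for a fully connected network once pruned nodes are removed follows. It assumes the compression ratio is the number of remaining weights and biases divided by the original count, and that FLOPs count one multiply and one add per weight (biases and activations ignored); the paper's exact definitions may differ, and the function names, layer widths, and active-node counts are hypothetical.

```python
from typing import Dict, List, Sequence


def dense_param_count(widths: Sequence[int]) -> int:
    # Weights plus biases of consecutive dense layers: fan_in * fan_out + fan_out.
    return sum(a * b + b for a, b in zip(widths[:-1], widths[1:]))


def dense_flops(widths: Sequence[int]) -> int:
    # One multiply and one add per weight; bias additions and activations ignored.
    return sum(2 * a * b for a, b in zip(widths[:-1], widths[1:]))


def pruning_summary(full_widths: Sequence[int], active_hidden: Sequence[int]) -> Dict[str, object]:
    """Compare a dense MLP with the same MLP after removing pruned hidden nodes.

    full_widths   = [input, hidden_1, ..., hidden_L, output]
    active_hidden = number of nodes kept in each hidden layer
    """
    hidden = list(full_widths[1:-1])
    if len(active_hidden) != len(hidden) or any(
        not 0 <= k <= n for k, n in zip(active_hidden, hidden)
    ):
        raise ValueError("need one active count per hidden layer, within [0, layer width]")
    pruned = [full_widths[0], *active_hidden, full_widths[-1]]
    active_proportion: List[float] = [k / n for k, n in zip(active_hidden, hidden)]
    return {
        "active_node_proportion": active_proportion,
        "compression_ratio": dense_param_count(pruned) / dense_param_count(full_widths),
        "flops_ratio": dense_flops(pruned) / dense_flops(full_widths),
    }


# Hypothetical 2-hidden-layer perceptron for 28x28 inputs and 10 classes.
summary = pruning_summary(full_widths=[784, 400, 400, 10], active_hidden=[100, 50])
```

Because whole nodes are removed, both the parameter count and the FLOPs of a dense layer shrink with the product of its surviving fan-in and fan-out, which is why node sparsity translates directly into lower inference cost.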

Theorems & Definitions (15)

  • Theorem 1
  • Corollary 2
  • Theorem 3
  • Corollary 4
  • Remark 5
  • Remark 6
  • Definition 7
  • Definition 8
  • Lemma 9
  • Lemma 10: Existence of Test Functions
  • ...and 5 more