Feature Preserving Shrinkage on Bayesian Neural Networks via the R2D2 Prior

Tsai Hor Chan, Dora Yan Zhang, Guosheng Yin, Lequan Yu

TL;DR

The paper tackles the challenge of prior choice in Bayesian neural networks by introducing the R2D2-Net, which uses the $R^2$-induced Dirichlet Decomposition (R2D2) prior to achieve strong sparsity without sacrificing important signals. A variational Gibbs inference scheme is developed to jointly estimate weights and shrinkage parameters, with closed-form KL divergences for several shrinkage components to yield accurate ELBO optimization. The authors establish a posterior contraction result, showing minimax-type convergence under regularity conditions, and demonstrate through simulations and real-data experiments that R2D2-Net attains superior predictive performance and more reliable uncertainty estimates, including robust OOD detection. The approach offers a principled, scalable framework for shrinkage in deep Bayesian models and holds promise for applications in medical imaging and broader Bayesian deep learning contexts.

Abstract

Bayesian neural networks (BNNs) treat neural network weights as random variables, aiming to provide posterior uncertainty estimates and to avoid overfitting by performing inference on the posterior weights. However, selecting appropriate prior distributions remains challenging, and BNNs may suffer from catastrophically inflated variance or poor predictive performance when the priors are poorly chosen. Existing BNN designs apply a variety of priors to the weights, but these priors either fail to shrink noisy signals sufficiently or tend to over-shrink important signals in the weights. To alleviate this problem, we propose a novel R2D2-Net, which imposes the $R^2$-induced Dirichlet Decomposition (R2D2) prior on the BNN weights. The R2D2-Net can effectively shrink irrelevant coefficients towards zero while preventing key features from over-shrinkage. To approximate the posterior distribution of weights more accurately, we further propose a variational Gibbs inference algorithm that combines the Gibbs updating procedure with gradient-based optimization. This strategy enhances stability and consistency in estimation when the variational objective involving the shrinkage parameters is non-convex. We also analyze the evidence lower bound (ELBO) and the posterior concentration rates from a theoretical perspective. Experiments on both natural and medical image classification and uncertainty estimation tasks demonstrate satisfactory performance of our method.
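
To make the shrinkage behaviour described in the abstract concrete, the following is a minimal sketch of drawing one weight vector from an R2D2-style hierarchy. It assumes the standard R2D2 formulation from the regression literature (a Dirichlet split of a global variance, exponential local scales, and a gamma-gamma global scale); the parameterization actually used in R2D2-Net may differ. The function name `sample_r2d2_prior` and the default values of `a_pi`, `b`, and `sigma2` are illustrative assumptions, not taken from the paper.

```python
import numpy as np

def sample_r2d2_prior(p, a_pi=0.5, b=0.5, sigma2=1.0, rng=None):
    """Draw p weights from an R2D2-style prior (illustrative sketch only).

    Assumed hierarchy (standard R2D2 form, not confirmed for R2D2-Net):
        xi        ~ Gamma(shape=b, rate=1)
        omega|xi  ~ Gamma(shape=a, rate=xi), with a = p * a_pi
        phi       ~ Dirichlet(a_pi, ..., a_pi)
        psi_j     ~ Exponential(rate=1/2)
        w_j       ~ Normal(0, sigma2 * psi_j * phi_j * omega / 2)
    """
    rng = np.random.default_rng() if rng is None else rng
    a = p * a_pi
    xi = rng.gamma(shape=b, scale=1.0)
    # NumPy parameterizes the gamma by scale, so rate xi becomes scale 1 / xi.
    omega = rng.gamma(shape=a, scale=1.0 / xi)
    # phi splits the global variance omega across the p weights and sums to 1.
    phi = rng.dirichlet(np.full(p, a_pi))
    # Exponential with rate 1/2 has mean 2, hence scale=2.0.
    psi = rng.exponential(scale=2.0, size=p)
    variance = sigma2 * psi * phi * omega / 2.0
    weights = rng.normal(0.0, np.sqrt(variance))
    return weights, phi

if __name__ == "__main__":
    w, phi = sample_r2d2_prior(p=10, rng=np.random.default_rng(0))
    print("weights:", np.round(w, 3))
    print("variance shares:", np.round(phi, 3))
```

With a small `a_pi`, the Dirichlet draw tends to put most of the variance budget on a few weights, which is one intuitive reading of shrinking irrelevant coefficients while leaving room for strong signals.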

Paper Structure

This paper contains 30 sections, 4 theorems, 35 equations, 14 figures, 12 tables.

Key Result

Theorem 1

Consider a DNN with $L_n$ layers and at most $K_n$ connections, where both $L_n$ and $K_n$ are increasing with $n$. Let $k_n \asymp \sqrt{r_n(\log K_n) / n}$ and denote $\mathbb{P}^*$ and $\mathbb{E}^*$ the respective probability measure and expectation with respect to the data $\mathcal{D}$. Assume for sufficiently large $n$, where $c$ is a constant, $d$ is the Hellinger distance between two dens

Figures (14)

  • Figure 1: An illustrative comparison of priors with different tail behaviors and concentration rates at zero. A prior with a heavier tail preserves stronger signals by putting more weights on them. A prior with a larger concentration rate around zero can shrink unnecessary or trivial features more effectively.
  • Figure 2: Overview of the proposed R2D2-Net with the yellow part representing the graphical model of each neuron and the blue part summarizing the variational Gibbs inference for computing the posterior distribution of weights.
  • Figure 3: Prediction mean and confidence intervals of R2D2-Net at test time on $y_i = x_i^3 + \epsilon_i, \epsilon_i \sim \mathcal{N}(0, 9)$. The number of layers is 3 and the number of samples is 100 during the validation phase. The blue dots are the ground truth data points, the yellow line is the mean of prediction and the blue shadow is the prediction interval. We observe that the R2D2-Net yields a smaller prediction variance than MC Dropout, Gaussian BNN, and Horseshoe BNN.
  • Figure 4: Ablation studies of our method over different hyperparameters. We run the three simulation scenarios (S1--S3) with an R2D2 MLP with $L = 3$, and report the testing MSEs for different values of the hyperparameters $a_\pi$ (left), $b$ (middle), and $\rho_0$ (right).
  • Figure 5: Density plots of the weight samples of Gaussian BNN, Horseshoe BNN, and R2D2-Net. We choose the weights that have the least magnitude from the first layer of a three-layer MLP. We observe that R2D2-Net has the highest concentration rate at zero.
  • ...and 9 more figures
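
The toy regression problem in the Figure 3 caption, $y_i = x_i^3 + \epsilon_i$ with $\epsilon_i \sim \mathcal{N}(0, 9)$, can be reproduced with a short data generator. This is a sketch under stated assumptions: the caption does not give the input range or the dataset size, so `x_low`, `x_high`, `n`, and the function name `make_cubic_data` are hypothetical choices made for illustration.

```python
import numpy as np

def make_cubic_data(n=100, x_low=-4.0, x_high=4.0, seed=0):
    """Generate (x, y) pairs with y = x^3 + eps, eps ~ N(0, 9).

    The input range and n are assumptions; only the functional form
    and the noise variance come from the Figure 3 caption.
    """
    rng = np.random.default_rng(seed)
    x = rng.uniform(x_low, x_high, size=n)
    # A noise variance of 9 corresponds to a standard deviation of 3.
    eps = rng.normal(0.0, 3.0, size=n)
    y = x ** 3 + eps
    return x, y

if __name__ == "__main__":
    x, y = make_cubic_data()
    print(x[:3], y[:3])
```

Because the noise standard deviation is 3, a well-calibrated prediction interval on this task should not be much narrower than that scale around the fitted mean.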

Theorems & Definitions (5)

  • Theorem 1
  • Remark 1.1
  • Theorem 2
  • Lemma 1
  • Lemma 2