
Dirichlet Scale Mixture Priors for Bayesian Neural Networks

August Arnstad, Leiv Rønneberg, Geir Storvik

TL;DR

This work introduces Dirichlet Scale Mixture (DSM) priors for Bayesian neural networks: a hierarchical, sparsity-inducing prior that partitions a variance budget across weight groups via a Dirichlet allocation. The DSM framework provides global, group, and local shrinkage, connects to horseshoe-type priors, and enables structured sparsity aligned with the network architecture. The authors derive theoretical properties of the induced dependence and shrinkage, linearize single-hidden-layer networks to study shrinkage directions, and show in linear-regression and real-data experiments that DSM priors yield sparser models with competitive predictive accuracy and greater robustness to pruning, particularly in small-to-moderate data regimes. The approach is also robust to adversarial perturbations and, through heavy-tailed shrinkage, offers a principled mechanism for countering the cold posterior effect, making it a compelling option for calibrated, efficient BNNs.
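To make the global/group/local hierarchy concrete, here is a minimal sampling sketch. It assumes (this is not confirmed as the authors' exact specification) that weight $w_{jk}$ in group $j$ has variance $\tau^2 \lambda_j^2 \xi_{jk}$, with a half-Student-t group scale $\lambda_j$ and a symmetric Dirichlet vector $\xi_j$ over the $p$ weights of the group. Because each $\xi_j$ sums to one, the group's variance budget $\tau^2 \lambda_j^2$ is split across its weights. The function name `sample_dsm_weights` and its arguments are hypothetical.

```python
import numpy as np


def sample_dsm_weights(n_groups, p, tau, nu, alpha, rng):
    """Draw one weight matrix from an assumed DSM-style hierarchy (hypothetical sketch).

    Assumed hierarchy:
      lambda_j ~ half-Student-t_nu(0, 1)        group scale, one per group j
      xi_j     ~ Dirichlet(alpha, ..., alpha)   local variance shares, p per group
      w_jk     ~ N(0, tau^2 * lambda_j^2 * xi_jk)
    Rows of the returned weight array are groups.
    """
    # Half-t draw: absolute value of a standard Student-t draw.
    lam = np.abs(rng.standard_t(nu, size=n_groups))
    # Each Dirichlet row sums to 1, so it splits the group budget tau^2 * lambda_j^2.
    xi = rng.dirichlet(np.full(p, alpha), size=n_groups)
    std = tau * lam[:, None] * np.sqrt(xi)
    w = std * rng.standard_normal((n_groups, p))
    return w, lam, xi


rng = np.random.default_rng(0)
w, lam, xi = sample_dsm_weights(n_groups=3, p=5, tau=0.1, nu=1.0, alpha=0.5, rng=rng)
print(w.shape)         # (3, 5)
print(xi.sum(axis=1))  # approximately 1 for every group
```

A small $\alpha$ concentrates each group's budget on a few weights, which is one way to read the "structured sparsity" described above; $\nu = 1$ gives a half-Cauchy group scale, as in horseshoe-type priors.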

Abstract

Neural networks are the cornerstone of modern machine learning, yet they can be difficult to interpret, give overconfident predictions, and are vulnerable to adversarial attacks. Bayesian neural networks (BNNs) alleviate some of these limitations but have problems of their own. Specifying prior distributions is a key step in BNNs and no trivial task, yet it is often skipped out of convenience. In this work, we propose a new class of prior distributions for BNNs, the Dirichlet scale mixture (DSM) prior, which addresses current limitations of BNNs through structured, sparsity-inducing shrinkage. Theoretically, we derive general dependence structures and shrinkage results for DSM priors and show how they manifest under the geometry induced by neural networks. In experiments on simulated and real-world data, we find that DSM priors encourage sparse networks through implicit feature selection, are robust under adversarial attacks, and deliver competitive predictive performance with substantially fewer effective parameters. Their advantages are most pronounced in moderately small data regimes with correlated features, and the resulting networks are more amenable to weight pruning. Moreover, by adopting heavy-tailed shrinkage mechanisms, our approach aligns with recent findings that such priors can mitigate the cold posterior effect, offering a principled alternative to the commonly used Gaussian priors.

Paper Structure (17 sections, 4 theorems, 119 equations, 20 figures, 5 tables)

Key Result

Theorem 4.1

Let $w_j$ follow the DSM prior with global scale $\tau$, group scale $\lambda_j \sim t^+_{\nu}(0, 1)$ and local scale $\xi_j \sim \mathrm{Beta}(\alpha, (p-1)\alpha)$ marginally. Assume $z_j = \sqrt{n}\,\sigma^{-1}\tau s_j > 0$ is fixed and given. The marginal prior distribution of $\kappa_j$ is then as per …, where $\tilde{C}(\nu, z_j) = \frac{\Gamma(\frac{\nu+1}{2})}{\sqrt{\nu\pi}\,\Gamma(\frac{\nu}{2})} \cdots$
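The density in the theorem can be sanity-checked by simulation. The sketch below assumes, by analogy with the classical horseshoe shrinkage factor $\kappa_j = 1/(1 + z_j^2 \lambda_j^2)$, that the DSM factor is $\kappa_j = 1/(1 + z_j^2 \lambda_j^2 \xi_j)$ with $\xi_j \sim \mathrm{Beta}(\alpha, (p-1)\alpha)$ as in the statement. That form is an assumption, not taken from the theorem itself, and `sample_kappa` with its `dirichlet` flag is a hypothetical helper. Since $\xi_j \le 1$, the assumed DSM factor is never smaller than the horseshoe one for the same $\lambda_j$.

```python
import numpy as np


def sample_kappa(z, nu, alpha, p, n_samples, rng, dirichlet=True):
    """Monte Carlo draws of the shrinkage factor kappa for a fixed z > 0 (hypothetical).

    Assumed forms, mirroring the horseshoe shrinkage factor:
      horseshoe-type: kappa = 1 / (1 + z^2 * lambda^2)
      DSM-type:       kappa = 1 / (1 + z^2 * lambda^2 * xi), xi ~ Beta(alpha, (p-1)*alpha)
    with lambda ~ half-Student-t_nu(0, 1); nu = 1 gives a half-Cauchy.
    """
    # lambda^2 for a half-t scale: the sign of the t draw is irrelevant once squared.
    lam2 = rng.standard_t(nu, size=n_samples) ** 2
    # Marginal Beta local scale for DSM; xi = 1 recovers the horseshoe-type factor.
    xi = rng.beta(alpha, (p - 1) * alpha, size=n_samples) if dirichlet else 1.0
    return 1.0 / (1.0 + z**2 * lam2 * xi)


rng = np.random.default_rng(0)
k_hs = sample_kappa(z=1.0, nu=1.0, alpha=0.5, p=10, n_samples=100_000, rng=rng, dirichlet=False)
k_dsm = sample_kappa(z=1.0, nu=1.0, alpha=0.5, p=10, n_samples=100_000, rng=rng)
# Fraction of prior mass close to full shrinkage (kappa near 1) under each prior.
print((k_hs > 0.9).mean(), (k_dsm > 0.9).mean())
```

With $\nu = 1$ and $z = 1$ the horseshoe-type factor is $\mathrm{Beta}(1/2, 1/2)$, so roughly 20% of its mass lies above 0.9; under the assumed DSM form the draws place noticeably more mass there, in line with the qualitative picture in Figure 2. A histogram of `k_dsm` can be compared directly against the closed-form density.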

Figures (20)

  • Figure 1: Monte Carlo estimates of the dispersion ratio $CV^2(\tilde{\lambda}_j^{\,2})$ for different scale priors as a function of their tail-controlling parameter $(\sigma, k, a, b, \nu)$. The curves are evaluated at three fixed values of the regularization parameter: (i) the median of the prior on $c$, (ii) the $0.9$ quantile of the prior on $c^2$, and (iii) a very large value of $c$, corresponding to an essentially unregularized regime. The horizontal line indicates the threshold $1/(p\alpha)$ separating negative and positive covariance regimes. Smaller values of $c$ attenuate dispersion and increasingly favor negative covariance among the variance components $\tilde{\lambda}_j^{\,2}\xi_{jk}$.
  • Figure 2: $p(\kappa \mid \sigma, \tau)$ under the Dirichlet horseshoe ($\nu=1$), the Dirichlet Student’s t, and the classical horseshoe. $\kappa=1$ indicates full shrinkage and $\kappa=0$ no shrinkage at all. The Dirichlet priors shrink more aggressively than the horseshoe: their shrinkage factor $\kappa$ places more mass close to $1$.
  • Figure 3: Boxplots of posterior samples of $(w_1, w_5, w_6)$ in the linear regression model for different correlations. The dotted blue line marks the true coefficient.
  • Figure 4: Estimated densities of posterior samples of $w_1, w_5, w_6$ (left) and histograms of $\kappa_1, \kappa_5, \kappa_6$ (right) in the linear regression model with $\rho=0.9$. The dashed black line marks the true coefficient and the dotted purple line the GLS estimate.
  • Figure 5: Boxplots of seed-level median CRPS for each model and training sample size on the independent Friedman (left) and correlated Friedman (right) datasets. For each training size $N \in \{100,200,500\}$, five independent datasets are used. Each box summarizes the five median CRPS values, where each value is computed from posterior predictive ensembles evaluated on a large generated test set.
  • ...and 15 more figures
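The threshold $1/(p\alpha)$ in Figure 1 follows from Dirichlet moments. Assume $\tilde{\lambda}_j^{\,2}$ is independent of $\xi_j \sim \mathrm{Dirichlet}(\alpha, \dots, \alpha)$ on $p$ components, and ignore the regularization parameter $c$. For $k \neq l$, $E[\xi_{jk}] = 1/p$ and $E[\xi_{jk}\xi_{jl}] = \alpha / (p(p\alpha + 1))$, so $\mathrm{Cov}(\tilde{\lambda}_j^{\,2}\xi_{jk}, \tilde{\lambda}_j^{\,2}\xi_{jl}) = E[\tilde{\lambda}_j^{\,4}]\,\alpha/(p(p\alpha+1)) - E[\tilde{\lambda}_j^{\,2}]^2/p^2$. This is positive exactly when $CV^2(\tilde{\lambda}_j^{\,2}) = E[\tilde{\lambda}_j^{\,4}]/E[\tilde{\lambda}_j^{\,2}]^2 - 1 > 1/(p\alpha)$. The sketch below checks this numerically with a hypothetical $\tilde{\lambda}_j^{\,2} \sim \mathrm{Gamma}(k, 1)$, for which $CV^2 = 1/k$; the function names are illustrative.

```python
import numpy as np


def cross_covariance(m2, m4, alpha, p):
    """Cov(L * xi_k, L * xi_l), k != l, with E[L] = m2, E[L^2] = m4 and
    xi ~ Dirichlet(alpha, ..., alpha) on p components, independent of L."""
    e_xi = 1.0 / p                            # E[xi_k]
    e_xixj = alpha / (p * (p * alpha + 1.0))  # E[xi_k * xi_l], k != l
    return m4 * e_xixj - (m2 * e_xi) ** 2


def mc_cross_covariance(shape, alpha, p, n, rng):
    """Monte Carlo estimate with the hypothetical choice L ~ Gamma(shape, 1)."""
    L = rng.gamma(shape, 1.0, size=n)
    xi = rng.dirichlet(np.full(p, alpha), size=n)
    v = L[:, None] * xi  # variance components L * xi_k
    return np.cov(v[:, 0], v[:, 1])[0, 1]


alpha, p = 0.5, 4  # threshold 1 / (p * alpha) = 0.5
rng = np.random.default_rng(0)
for shape in (0.5, 5.0):  # CV^2 = 1 / shape: 2.0 (above threshold), 0.2 (below)
    # For Gamma(shape, 1): E[L] = shape, E[L^2] = shape * (shape + 1).
    exact = cross_covariance(shape, shape * (shape + 1.0), alpha, p)
    mc = mc_cross_covariance(shape, alpha, p, 200_000, rng)
    print(f"CV^2={1 / shape:.1f}  exact={exact:+.4f}  monte_carlo={mc:+.4f}")
```

For these settings the exact covariance is $+0.015625$ when $CV^2 = 2$ and $-0.3125$ when $CV^2 = 0.2$, matching the sign change at $1/(p\alpha) = 0.5$.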

Theorems & Definitions (6)

  • Theorem 4.1
  • Definition 5.1
  • Definition 5.2
  • Lemma 7.1: Expectation of transformed Beta variable I
  • Lemma 7.2: Expectation of transformed Beta variable II
  • Lemma 7.3: A priori distribution of shrinkage factor for Student-t local scale