Bayesian Neural Networks with Domain Knowledge Priors

Dylan Sam, Rattana Pukdee, Daniel P. Jeong, Yewon Byun, J. Zico Kolter

TL;DR

This work proposes a framework for integrating general forms of domain knowledge into a BNN prior through variational inference, while enabling computationally efficient posterior inference and sampling, and shows that BNNs using the proposed domain knowledge priors outperform those with standard priors.

Abstract

Bayesian neural networks (BNNs) have recently gained popularity due to their ability to quantify model uncertainty. However, specifying a prior for BNNs that captures relevant domain knowledge is often extremely challenging. In this work, we propose a framework for integrating general forms of domain knowledge (i.e., any knowledge that can be represented by a loss function) into a BNN prior through variational inference, while enabling computationally efficient posterior inference and sampling. Specifically, our approach results in a prior over neural network weights that assigns high probability mass to models that better align with our domain knowledge, leading to posterior samples that also exhibit this behavior. We show that BNNs using our proposed domain knowledge priors outperform those with standard priors (e.g., isotropic Gaussian, Gaussian process), successfully incorporating diverse types of prior information such as fairness, physics rules, and healthcare knowledge and achieving better predictive performance. We also present techniques for transferring the learned priors across different model architectures, demonstrating their broad utility across various settings.
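
To make the idea of "domain knowledge as a loss function" concrete, the following is a minimal sketch and not the paper's implementation. It assumes a toy linear model with weights `w`, a hypothetical knowledge loss `phi(w) = w[0]**2` (the model should ignore feature 0), an isotropic Gaussian base prior N(0, I), and a temperature `tau`; all of these names and choices are illustrative assumptions. It fits a diagonal Gaussian prior by minimizing E_q[phi(w)] / tau + KL(q || N(0, I)) with reparameterized gradients.

```python
import numpy as np

def fit_knowledge_prior(dim=3, tau=0.1, n_samples=4096, steps=2000, lr=0.02, seed=0):
    """Fit q(w) = N(mu, diag(sigma^2)) by minimizing
    E_q[phi(w)] / tau + KL(q || N(0, I)), where phi(w) = w[0]**2.

    phi, tau, and the N(0, I) base prior are illustrative assumptions.
    """
    rng = np.random.default_rng(seed)
    # Fixed noise draws (common random numbers) keep the objective deterministic.
    eps = rng.standard_normal((n_samples, dim))
    mu = np.zeros(dim)
    rho = np.zeros(dim)  # sigma = exp(rho) keeps the scale positive
    for _ in range(steps):
        sigma = np.exp(rho)
        w = mu + sigma * eps  # reparameterization trick
        # Gradient of phi(w) = w[0]**2 with respect to w.
        grad_phi = np.zeros_like(w)
        grad_phi[:, 0] = 2.0 * w[:, 0]
        # KL(q || N(0, I)) = 0.5 * sum(sigma^2 + mu^2 - 1 - 2 * rho)
        grad_mu = grad_phi.mean(axis=0) / tau + mu
        grad_rho = (grad_phi * eps * sigma).mean(axis=0) / tau + sigma**2 - 1.0
        mu -= lr * grad_mu
        rho -= lr * grad_rho
    return mu, np.exp(rho)

mu, sigma = fit_knowledge_prior()
# For this quadratic phi the optimum has a closed form to compare against:
# mu = 0, sigma[0]**2 = 1 / (1 + 2 / tau), and sigma[i] = 1 for i > 0.
print(mu, sigma**2)
```

In this toy setting the learned prior shrinks the variance of the penalized weight while leaving the other weights at the base prior, which is the behavior the abstract describes: more probability mass on models that agree with the stated knowledge.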

Paper Structure (42 sections, 14 equations, 5 figures, 8 tables)


Figures (5)

  • Figure 1: Our framework (Banana; bottom row) compared to standard practice (top row) for training BNNs. Our method adds an intermediate step that learns an informative prior via a variational objective, with domain knowledge incorporated through a loss function $\phi$. The informative prior encourages models that exhibit desirable behavior.
  • Figure 2: Comparison of samples from the Banana prior (in green) against those from an isotropic Gaussian prior (in red) on the Folktables dataset. The solid line is the Pareto frontier, and the dots represent samples from each prior. Samples from Banana tend to be more Pareto-optimal, achieving higher accuracy while satisfying fairness constraints.
  • Figure 3: Change in test accuracy on the DecoyMNIST task when varying the number of mixture components in the informative prior in Banana. Results are averaged over 5 seeds, and the shaded region represents mean $\pm$ s.e.
  • Figure 4: Change in test accuracy on the DecoyMNIST task when varying the number of models sampled to compute our posterior average in Banana. Results are averaged over 5 seeds, and the shaded region represents mean $\pm$ s.e.
  • Figure A1: Results when varying the rank to approximate our informative prior in Banana on the DecoyMNIST task. Results are averaged over 5 seeds, and the shaded region represents mean $\pm$ s.e.

Theorems & Definitions (1)

  • Definition 1