Bayesian Deep Learning and a Probabilistic Perspective of Generalization

Andrew Gordon Wilson, Pavel Izmailov

TL;DR

The paper reframes generalization in deep learning as a problem of Bayesian marginalization over weight configurations, arguing that the predictive distribution should average over plausible models rather than fixate on a single set of weights. It shows that deep ensembles approximate Bayesian model averaging, and introduces MultiSWAG to marginalize across multiple basins of attraction, yielding superior accuracy and calibration and mitigating double descent. By analyzing priors over functions and connections to Gaussian processes, the work explains mysterious generalization phenomena and demonstrates that random-label fits can occur within flexible priors, while still favoring clean data under marginal likelihood. The authors also explore tempering as a principled tool for handling model misspecification, data augmentation effects, and inexact inference, offering practical guidelines for Bayesian deep learning that improve predictive uncertainty and scalability.
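The averaging described above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: it assumes a toy linear softmax classifier, uses random matrices as stand-ins for independently trained solutions, and replaces the SWAG posterior around each solution with an isotropic Gaussian of hypothetical width `scale`. The names `predict_proba`, `bma_predict`, and `sample_around_modes` are invented for this example.

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax over the last axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def predict_proba(w, x):
    # p(y | x, w) for a toy linear classifier; w has shape (d, k).
    return softmax(x @ w)

def bma_predict(weight_samples, x):
    # Monte Carlo estimate of the model average:
    # p(y | x, D) ~= (1/J) * sum_j p(y | x, w_j)
    return np.mean([predict_proba(w, x) for w in weight_samples], axis=0)

def sample_around_modes(modes, scale, n_per_mode, rng):
    # Equally weighted mixture of isotropic Gaussians, one per mode.
    # ASSUMPTION: isotropic noise is a simplification for illustration only.
    return [m + scale * rng.standard_normal(m.shape)
            for m in modes for _ in range(n_per_mode)]

rng = np.random.default_rng(0)
d, k = 5, 3
x = rng.standard_normal((4, d))

# Stand-ins for independently trained solutions (one per basin).
modes = [rng.standard_normal((d, k)) for _ in range(3)]

p_single = predict_proba(modes[0], x)    # a single setting of weights
p_ensemble = bma_predict(modes, x)       # deep-ensemble-style average
p_mixture = bma_predict(                 # average across and within basins
    sample_around_modes(modes, scale=0.1, n_per_mode=10, rng=rng), x)
```

The probabilities are averaged rather than the weights, so each output row remains a valid distribution. A single weight setting, an ensemble of modes, and a mixture around modes differ only in which weight samples are passed to `bma_predict`.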

Abstract

The key distinguishing property of a Bayesian approach is marginalization, rather than using a single setting of weights. Bayesian marginalization can particularly improve the accuracy and calibration of modern deep neural networks, which are typically underspecified by the data, and can represent many compelling but different solutions. We show that deep ensembles provide an effective mechanism for approximate Bayesian marginalization, and propose a related approach that further improves the predictive distribution by marginalizing within basins of attraction, without significant overhead. We also investigate the prior over functions implied by a vague distribution over neural network weights, explaining the generalization properties of such models from a probabilistic perspective. From this perspective, we explain results that have been presented as mysterious and distinct to neural network generalization, such as the ability to fit images with random labels, and show that these results can be reproduced with Gaussian processes. We also show that Bayesian model averaging alleviates double descent, resulting in monotonic performance improvements with increased flexibility. Finally, we provide a Bayesian perspective on tempering for calibrating predictive distributions.
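For the tempering mentioned in the last sentence, one common way to write a tempered posterior is shown below. The symbols $T$ (temperature) and $p_T$ are notation assumed here for illustration and do not appear in the text above; other conventions temper the prior as well.

```latex
p_T(w \mid \mathcal{D}) \;\propto\; p(\mathcal{D} \mid w)^{1/T} \, p(w)
```

Under this convention, $T = 1$ recovers the standard Bayesian posterior, $T < 1$ sharpens the likelihood term (often called a cold posterior), and $T > 1$ flattens it.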

Paper Structure

This paper contains 44 sections, 2 theorems, 24 equations, 19 figures.

Key Result

Proposition 1

Suppose the network has no bias vectors, i.e. $\beta_1 = \ldots = \beta_n = 0$. Then the scales $\alpha_i$ of the prior distribution over the weights only affect the output scale of the network.
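A quick numerical check of this statement might look as follows. It is a sketch under stated assumptions: ReLU activations (which are positively homogeneous), positive scales $\alpha_i$, and a Gaussian prior in which a draw with scale $\alpha_i$ equals $\alpha_i$ times a standard normal draw. The layer sizes and the helper `relu_net` are hypothetical.

```python
import numpy as np

def relu_net(weights, x):
    # Fully connected ReLU network with no bias vectors.
    h = x
    for W in weights[:-1]:
        h = np.maximum(h @ W, 0.0)
    return h @ weights[-1]

rng = np.random.default_rng(0)
sizes = [4, 16, 16, 2]  # hypothetical layer widths

# One standard-normal draw per layer, then rescaled by alpha_i,
# i.e. a draw from N(0, alpha_i^2 I) for each layer.
base = [rng.standard_normal((m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
alphas = [0.5, 2.0, 3.0]
scaled = [a * W for a, W in zip(alphas, base)]

x = rng.standard_normal((8, sizes[0]))
f_base = relu_net(base, x)
f_scaled = relu_net(scaled, x)

# Because ReLU(a * z) = a * ReLU(z) for a > 0, the scales factor out
# layer by layer and only multiply the output by their product.
output_scale = float(np.prod(alphas))
print(np.allclose(f_scaled, output_scale * f_base))
```

In this sketch the function computed by the rescaled network is the base function multiplied by the product of the $\alpha_i$, which matches the claim that the prior scales affect only the output scale. Adding nonzero bias vectors breaks the equality.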

Figures (19)

  • Figure 1: Airline passenger numbers recorded monthly.
  • Figure 2: A probabilistic perspective of generalization. (a) Ideally, a model supports a wide range of datasets, but with inductive biases that provide high prior probability to a particular class of problems being considered. Here, the CNN is preferred over the linear model and the fully-connected MLP for CIFAR-10 (while we do not consider MLP models to have poor inductive biases in general, here we are considering a hypothetical example involving images and a very large MLP). (b) By representing a large hypothesis space, a model can contract around a true solution, which in the real world is often very sophisticated. (c) With truncated support, a model will converge to an erroneous solution. (d) Even if the hypothesis space contains the truth, a model will not efficiently contract unless it also has reasonable inductive biases.
  • Figure 3: Approximating the BMA. $p(y|x,\mathcal{D}) = \int p(y|x,w)p(w | \mathcal{D}) dw$. Top: $p(w|\mathcal{D})$, with representations from VI (orange), deep ensembles (blue), MultiSWAG (red). Middle: $p(y|x,w)$ as a function of $w$ for a test input $x$. This function does not vary much within modes, but changes significantly between modes. Bottom: Distance between the true predictive distribution and the approximation, as a function of representing a posterior at an additional point $w$, assuming we have sampled the mode in dark green. There is more to be gained by exploring new basins than by continuing to explore the same basin.
  • Figure 4: Approximating the true predictive distribution. (a): A close approximation of the true predictive distribution obtained by combining $200$ HMC chains. (b): Deep ensembles predictive distribution using $50$ independently trained networks. (c): Predictive distribution for factorized variational inference (VI). (d): Convergence of the predictive distributions for deep ensembles and variational inference as a function of the number of samples; we measure the average Wasserstein distance between the marginals in the range of input positions. The multi-basin deep ensembles approach provides a more faithful approximation of the Bayesian predictive distribution than the conventional single-basin VI approach, which is overconfident between data clusters. The top panels show the Wasserstein distance between the true predictive distribution and the deep ensemble and VI approximations, as a function of inputs $x$.
  • Figure 5: Negative log likelihood for Deep Ensembles, MultiSWAG and MultiSWA using a PreResNet-20 on CIFAR-10 with varying intensity of the Gaussian blur corruption. The image in each plot shows the intensity of corruption. For all levels of intensity, MultiSWAG and MultiSWA outperform Deep Ensembles for a small number of independent models. For high levels of corruption MultiSWAG significantly outperforms other methods even for many independent models. We present results for other corruptions in the Appendix.
  • ...and 14 more figures

Theorems & Definitions (4)

  • Proposition 1
  • proof
  • Proposition 2
  • proof