Bayesian Deep Learning and a Probabilistic Perspective of Generalization
Andrew Gordon Wilson, Pavel Izmailov
TL;DR
The paper frames generalization in deep learning in terms of Bayesian marginalization: the predictive distribution should average over plausible weight settings rather than rely on a single one. It shows that deep ensembles act as an effective approximation to Bayesian model averaging, and proposes MultiSWAG, which additionally marginalizes within each of several basins of attraction, improving accuracy and calibration at little extra cost and alleviating double descent. By examining the prior over functions induced by a vague prior over weights, and its connection to Gaussian processes, the authors account for behavior that has been presented as mysterious and specific to neural networks: a flexible prior allows a model to fit images with random labels, while the marginal likelihood still favors the correctly labeled data. The authors also discuss tempering from a Bayesian perspective, relating it to model misspecification, data augmentation, and inexact inference.
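As a reminder of what marginalization means here, the Bayesian model average and its simple Monte Carlo approximation can be written as follows. The notation is an assumption made for illustration: x is a test input, y its output, D the training data, w the network weights, and J the number of weight settings averaged over.

```latex
p(y \mid x, \mathcal{D})
  = \int p(y \mid x, w)\, p(w \mid \mathcal{D})\, dw
  \approx \frac{1}{J} \sum_{j=1}^{J} p(y \mid x, w_j),
  \qquad w_j \sim p(w \mid \mathcal{D}).
```

Classical training corresponds to J = 1 with a single point estimate of w. In a deep ensemble, the w_j are independently trained solutions, which can be read as approximate draws from different modes of the posterior.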
Abstract
The key distinguishing property of a Bayesian approach is marginalization, rather than using a single setting of weights. Bayesian marginalization can particularly improve the accuracy and calibration of modern deep neural networks, which are typically underspecified by the data, and can represent many compelling but different solutions. We show that deep ensembles provide an effective mechanism for approximate Bayesian marginalization, and propose a related approach that further improves the predictive distribution by marginalizing within basins of attraction, without significant overhead. We also investigate the prior over functions implied by a vague distribution over neural network weights, explaining the generalization properties of such models from a probabilistic perspective. From this perspective, we explain results that have been presented as mysterious and distinct to neural network generalization, such as the ability to fit images with random labels, and show that these results can be reproduced with Gaussian processes. We also show that Bayesian model averaging alleviates double descent, resulting in monotonic performance improvements with increased flexibility. Finally, we provide a Bayesian perspective on tempering for calibrating predictive distributions.
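The averaging step behind deep ensembles can be sketched in a few lines of NumPy. This is a minimal illustration and not the authors' implementation: the "ensemble members" are random logits standing in for independently trained networks, and the names `softmax`, `model_average`, and `nll` are hypothetical helpers introduced only for this example.

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def model_average(member_probs):
    """Average class probabilities over ensemble members.

    member_probs has shape (J, N, C): J members, N inputs, C classes.
    The result has shape (N, C) and is the equal-weight Monte Carlo
    estimate of the model-averaged predictive distribution.
    """
    return member_probs.mean(axis=0)


def nll(probs, labels):
    """Mean negative log-likelihood of the true labels."""
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))


rng = np.random.default_rng(0)
J, N, C = 5, 200, 10
labels = rng.integers(0, C, size=N)

# Assumption: synthetic logits stand in for J trained networks.
# Each member gets noisy logits with a boost on the true class.
logits = rng.normal(size=(J, N, C))
logits[:, np.arange(N), labels] += 2.0

member_probs = softmax(logits)             # shape (J, N, C)
avg_probs = model_average(member_probs)    # shape (N, C)

member_nlls = [nll(p, labels) for p in member_probs]
avg_nll = nll(avg_probs, labels)

print("mean member NLL:", np.mean(member_nlls))
print("averaged-prediction NLL:", avg_nll)
```

By Jensen's inequality, the negative log-likelihood of the averaged prediction is never worse than the mean negative log-likelihood of the individual members. The sketch shows only this averaging effect; it does not reproduce the paper's experiments or the additional within-basin marginalization.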
