Scalable Bayesian Learning with posteriors
Samuel Duffield, Kaelan Donatella, Johnathan Chiu, Phoebe Klett, Daniel Simpson
TL;DR
This work tackles the computational challenge of Bayesian learning in modern, large-scale models by introducing posteriors, a minibatch-first PyTorch library that unifies and scales uncertainty quantification across Laplace approximations, variational inference (VI), and stochastic gradient MCMC (SGMCMC). By formalizing SGD as the zero-temperature limit of SGMCMC, the authors enable a seamless transition from optimization to Bayesian sampling and show that parallel SGMCMC chains can function as Bayesian deep ensembles, giving scalable and diverse posterior approximations. Through experiments on the cold posterior effect, continual learning with LoRA, and Bayesian fine-tuning of Llama 3, the paper demonstrates that Bayesian approaches can improve generalization, support online learning without catastrophic forgetting, and aid out-of-distribution detection, while also highlighting limitations of Gaussian approximations in some regimes. The proposed framework and open-source tooling offer a practical path to integrating principled uncertainty into large language model workflows, with potential impact on the reliability and interpretability of production ML systems.
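The statement that SGD is a low-temperature limit of SGMCMC can be made concrete with one common parameterization of tempered stochastic gradient Langevin dynamics (SGLD). The notation below is an assumption for illustration, not necessarily the exact scaling used in the paper or in posteriors: U(θ) is the negative log joint density, p_T is the tempered posterior at temperature T, 𝓑 is a minibatch of the N data points, ε is the step size and ξ_k is standard Gaussian noise.

```latex
\begin{aligned}
U(\theta) &= -\log p(\theta) - \sum_{i=1}^{N} \log p(y_i \mid \theta),
  \qquad p_T(\theta \mid y) \propto \exp\!\big(-U(\theta)/T\big), \\
\nabla \tilde U(\theta) &= -\nabla \log p(\theta)
  - \frac{N}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \nabla \log p(y_i \mid \theta), \\
\theta_{k+1} &= \theta_k - \epsilon \, \nabla \tilde U(\theta_k)
  + \sqrt{2 \epsilon T} \, \xi_k, \qquad \xi_k \sim \mathcal{N}(0, I), \\
T = 0: \quad \theta_{k+1} &= \theta_k - \epsilon \, \nabla \tilde U(\theta_k).
\end{aligned}
```

The continuous-time Langevin dynamics behind this update has stationary density proportional to exp(−U(θ)/T), so T = 1 targets the Bayesian posterior (up to discretization and gradient-noise error that shrinks with ε), values 0 < T < 1 give the "cold" posteriors studied in the experiments, and T = 0 removes the noise entirely, leaving plain minibatch SGD on the MAP objective U.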
Abstract
Although theoretically compelling, Bayesian learning with modern machine learning models is computationally challenging since it requires approximating a high-dimensional posterior distribution. In this work, we (i) introduce posteriors, an easily extensible PyTorch library hosting general-purpose implementations that make Bayesian learning accessible and scalable to large data and parameter regimes; (ii) present a tempered framing of stochastic gradient Markov chain Monte Carlo, as implemented in posteriors, that transitions seamlessly into optimization and unveils a minor modification to deep ensembles to ensure they are asymptotically unbiased for the Bayesian posterior; and (iii) demonstrate and compare the utility of Bayesian approximations through experiments including an investigation into the cold posterior effect and applications with large language models.
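The relationship between deep ensembles and parallel SGMCMC described in (ii) can be illustrated on a toy problem. The sketch below is plain Python and does not use the posteriors API; the conjugate Gaussian model and the helper names make_data, exact_posterior, minibatch_grad_u and run_ensemble are hypothetical. It assumes the tempered SGLD update θ ← θ − ε∇Ũ(θ) + sqrt(2εT)·ξ with a minibatch gradient ∇Ũ, and it reads the abstract's "minor modification" as injecting this Gaussian noise (T = 1) into otherwise independent SGD runs (T = 0); that reading is our assumption, so consult the paper for the precise construction.

```python
import math
import random
import statistics


def make_data(n, theta_true, noise_sd, seed=0):
    """Hypothetical toy data: y_i ~ N(theta_true, noise_sd^2)."""
    rng = random.Random(seed)
    return [rng.gauss(theta_true, noise_sd) for _ in range(n)]


def exact_posterior(data, prior_sd, noise_sd):
    """Closed-form N(mean, sd^2) posterior under a N(0, prior_sd^2) prior."""
    precision = 1.0 / prior_sd**2 + len(data) / noise_sd**2
    mean = (sum(data) / noise_sd**2) / precision
    return mean, math.sqrt(1.0 / precision)


def minibatch_grad_u(theta, data, batch_size, prior_sd, noise_sd, rng):
    """Unbiased minibatch estimate of grad U(theta),
    where U = -log prior - sum_i log likelihood_i."""
    batch = rng.choices(data, k=batch_size)  # sample with replacement
    scale = len(data) / batch_size           # rescale batch sum to full-data sum
    grad_prior = theta / prior_sd**2
    grad_lik = -scale * sum(y - theta for y in batch) / noise_sd**2
    return grad_prior + grad_lik


def run_ensemble(data, temperature, n_chains=200, n_steps=1000,
                 step_size=1e-4, batch_size=10, prior_sd=1.0,
                 noise_sd=1.0, seed=1):
    """Run independent tempered-SGLD chains; return each chain's final value.
    temperature=0 removes the injected noise, so each chain is plain
    minibatch SGD and the collection behaves like a toy deep ensemble."""
    rng = random.Random(seed)
    noise_scale = math.sqrt(2.0 * step_size * temperature)
    finals = []
    for _ in range(n_chains):
        theta = rng.uniform(-3.0, 3.0)  # diverse initializations
        for _ in range(n_steps):
            grad = minibatch_grad_u(theta, data, batch_size,
                                    prior_sd, noise_sd, rng)
            theta = theta - step_size * grad + noise_scale * rng.gauss(0.0, 1.0)
        finals.append(theta)
    return finals


# Usage: compare both ensembles with the exact conjugate posterior.
data = make_data(n=100, theta_true=0.8, noise_sd=1.0)
post_mean, post_sd = exact_posterior(data, prior_sd=1.0, noise_sd=1.0)
sgd_ensemble = run_ensemble(data, temperature=0.0)   # deep ensemble (T = 0)
sgld_ensemble = run_ensemble(data, temperature=1.0)  # parallel SGLD (T = 1)
print(f"exact posterior sd:  {post_sd:.4f}")
print(f"T=0 ensemble spread: {statistics.stdev(sgd_ensemble):.4f}")
print(f"T=1 ensemble spread: {statistics.stdev(sgld_ensemble):.4f}")
```

At T = 0 every chain contracts toward the MAP estimate, and the remaining spread comes only from minibatch gradient noise, so it understates posterior uncertainty. At T = 1 the same ensemble, with noise added at each step, should have a spread close to the exact posterior standard deviation, up to a discretization and gradient-noise bias that shrinks as the step size decreases.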
