
Scalable Bayesian Learning with posteriors

Samuel Duffield, Kaelan Donatella, Johnathan Chiu, Phoebe Klett, Daniel Simpson

TL;DR

This work tackles the computational challenge of Bayesian learning in modern, large-scale models by introducing posteriors, a minibatch-first PyTorch library that unifies and scales uncertainty quantification across Laplace, VI, and SGMCMC methods. By formalizing SGD as the low-temperature limit of SGMCMC, the authors enable a seamless transition from optimization to Bayesian sampling and show that parallel SGMCMC can act as a Bayesian version of deep ensembles, giving scalable and diverse posterior approximations. Through experiments on the cold posterior effect, continual learning with LoRA, and Bayesian fine-tuning of Llama 3, the paper demonstrates that Bayesian approaches can improve generalization, support online learning without forgetting, and aid out-of-distribution detection, while also highlighting limitations of Gaussian approximations in some regimes. The proposed framework and open-source tooling offer a practical path to integrating principled uncertainty into large-language-model workflows, with potential benefits for the reliability and interpretability of production ML systems.
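A minimal sketch of the tempered framing, written in plain Python rather than with the posteriors API: it runs Langevin dynamics, $\theta \leftarrow \theta + \epsilon \nabla \log p(\theta) + \sqrt{2 \epsilon T}\, \xi$ with $\xi \sim N(0, 1)$, on a toy standard-Gaussian log posterior. The toy target, the exact (non-minibatch) gradient, the step size and the function names are all illustrative assumptions. At temperature $T = 0$ the noise term vanishes and the update is plain gradient ascent on the log posterior (SGD on the negative log posterior); for $T > 0$ the chain targets $p(\theta)^{1/T}$, so $T = 1$ samples the posterior itself.

```python
import math
import random


def grad_log_post(theta: float) -> float:
    """Gradient of a toy log posterior, log p(theta) = -theta**2 / 2 (standard Gaussian)."""
    return -theta


def tempered_langevin_step(theta: float, step_size: float, temperature: float,
                           rng: random.Random) -> float:
    """One Euler-Maruyama Langevin step; for temperature > 0 it targets p(theta) ** (1 / temperature).

    With temperature == 0 the noise term is zero and this is plain gradient
    ascent on log p(theta), i.e. gradient descent on the negative log posterior.
    """
    noise = math.sqrt(2.0 * step_size * temperature) * rng.gauss(0.0, 1.0)
    return theta + step_size * grad_log_post(theta) + noise


def run_chain(theta0: float, step_size: float, temperature: float,
              n_steps: int, seed: int = 0) -> list[float]:
    """Return the trajectory of a single tempered chain."""
    rng = random.Random(seed)
    theta, trajectory = theta0, []
    for _ in range(n_steps):
        theta = tempered_langevin_step(theta, step_size, temperature, rng)
        trajectory.append(theta)
    return trajectory


if __name__ == "__main__":
    sgd = run_chain(3.0, step_size=0.05, temperature=0.0, n_steps=2_000)
    print(sgd[-1])  # close to 0.0: the T = 0 chain converges to the mode
    samples = run_chain(3.0, step_size=0.05, temperature=1.0, n_steps=100_000, seed=1)[1_000:]
    mean = sum(samples) / len(samples)
    print(sum((x - mean) ** 2 for x in samples) / len(samples))  # roughly 1, the posterior variance
```

With `step_size=0.05` the discretised chain's stationary variance is slightly above $T$ (about $1.03\,T$ for this target), a reminder that a finite step size introduces some bias.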

Abstract

Although theoretically compelling, Bayesian learning with modern machine learning models is computationally challenging since it requires approximating a high-dimensional posterior distribution. In this work, we (i) introduce posteriors, an easily extensible PyTorch library hosting general-purpose implementations that make Bayesian learning accessible and scalable to large data and parameter regimes; (ii) present a tempered framing of stochastic gradient Markov chain Monte Carlo, as implemented in posteriors, that transitions seamlessly into optimization and unveils a minor modification to deep ensembles to ensure they are asymptotically unbiased for the Bayesian posterior; and (iii) demonstrate and compare the utility of Bayesian approximations through experiments including an investigation into the cold posterior effect and applications with large language models.
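To illustrate the parallel-SGMCMC reading of deep ensembles, the sketch below runs many independent chains from random initialisations on a toy bimodal target, an equal mixture of $N(-2, 0.5^2)$ and $N(2, 0.5^2)$ chosen purely for illustration. With temperature 0 each chain is plain gradient ascent, so the final states form a deep ensemble that collapses onto the modes; with temperature 1 the only change is the injected noise, and the chains spread out according to the posterior within each mode. It is a plain-Python sketch with exact gradients, not the posteriors implementation, and all function names are hypothetical.

```python
import math
import random


def grad_log_mixture(theta: float, mu: float = 2.0, s: float = 0.5) -> float:
    """Gradient of log p for an equal-weight mixture N(-mu, s^2) + N(mu, s^2)."""
    a = -((theta - mu) ** 2) / (2 * s * s)
    b = -((theta + mu) ** 2) / (2 * s * s)
    m = max(a, b)  # subtract the max before exponentiating, for numerical stability
    wa, wb = math.exp(a - m), math.exp(b - m)
    ra, rb = wa / (wa + wb), wb / (wa + wb)  # component responsibilities
    return (ra * -(theta - mu) + rb * -(theta + mu)) / (s * s)


def parallel_chains(n_chains: int, temperature: float, n_steps: int = 1_000,
                    step_size: float = 0.01, seed: int = 0) -> list[float]:
    """Run independent tempered Langevin chains from uniform random initialisations.

    temperature=0.0 -> every chain is gradient ascent: a deep ensemble.
    temperature=1.0 -> every chain is Langevin sampling: parallel SGMCMC.
    Returns the final state of each chain.
    """
    rng = random.Random(seed)
    finals = []
    for _ in range(n_chains):
        theta = rng.uniform(-4.0, 4.0)
        for _ in range(n_steps):
            noise = math.sqrt(2 * step_size * temperature) * rng.gauss(0.0, 1.0)
            theta += step_size * grad_log_mixture(theta) + noise
        finals.append(theta)
    return finals


def within_mode_std(points: list[float]) -> float:
    """Standard deviation of |theta|, which folds the two symmetric modes onto one."""
    folded = [abs(p) for p in points]
    mean = sum(folded) / len(folded)
    return math.sqrt(sum((f - mean) ** 2 for f in folded) / len(folded))


if __name__ == "__main__":
    ensemble = parallel_chains(200, temperature=0.0)
    sgmcmc = parallel_chains(200, temperature=1.0, seed=1)
    print(within_mode_std(ensemble))  # essentially 0: members sit exactly on the modes
    print(within_mode_std(sgmcmc))    # roughly 0.5, the per-mode posterior std s
```

Because the barrier between the modes is high, individual chains rarely switch modes: coverage of both modes comes from running many chains from different initialisations, as in a deep ensemble, while the within-mode spread comes from the SGMCMC noise.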

Paper Structure

This paper contains 28 sections, 29 equations, 12 figures, and 3 tables.

Figures (12)

  • Figure 1: Pictorial representation of the benefits of Bayesian learning. Left: Averaging over multiple plausible fits to the data improves out-of-distribution generalisation. Center: Adapting to new online data without forgetting previous data. Right: Decomposing predictive uncertainty, with epistemic uncertainty providing an improved indicator for out-of-distribution detection.
  • Figure 2: Trajectories of various sampling methods for a toy multimodal posterior. The deep ensemble concentrates on the modes, serial SGMCMC struggles to transfer between modes, and parallel SGMCMC combines the benefits of both.
  • Figure 3: posteriors code snippet to train a classifier with variational inference. posteriors recommends normalising the log posterior across the batch so that its scale is independent of both the batch size and the dataset size $N$. Scaling the prior and the temperature by $N^{-1}$ ensures the posterior is still correctly targeted.
  • Figure 4: Investigation into the cold-posterior effect for a range of posteriors algorithms for a CNN-LSTM model (Wenzel et al., 2020) on the IMDB dataset (Maas et al., 2011). Non-linearized Laplace EF and Laplace GGN are indistinguishable (left panel). All Gaussian approximations are diagonal. Error bars show one standard deviation over 5 random seeds.
  • Figure 5: Continual learning with Llama 2. The online SGD and Laplace methods train on one book after another, whilst offline SGD sees all books simultaneously and so serves as a reference for the network's capacity. Vertical dashed lines mark episode changes.
  • ...and 7 more figures
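The normalisation described in the Figure 3 caption can be checked numerically. The sketch below assumes a toy Gaussian model ($y_i \sim N(\theta, 1)$ with prior $\theta \sim N(0, 10^2)$) and uses hypothetical function names; it is not the posteriors API. It compares the full log posterior with the batch-normalised version, in which the log prior is divided by $N$ and the log likelihood is averaged over the batch. Dividing the normalised value by the temperature $T = 1/N$ recovers the full log posterior, so a sampler run at that temperature targets the same distribution; temperatures below $1/N$ in these units correspond to a cold posterior.

```python
import math
import random


def log_prior(theta: float, prior_sd: float = 10.0) -> float:
    """Log density of the toy prior theta ~ N(0, prior_sd^2)."""
    return -0.5 * (theta / prior_sd) ** 2 - math.log(prior_sd * math.sqrt(2 * math.pi))


def log_lik(theta: float, y: float) -> float:
    """Log density of the toy likelihood y ~ N(theta, 1)."""
    return -0.5 * (y - theta) ** 2 - 0.5 * math.log(2 * math.pi)


def full_log_posterior(theta: float, data: list[float]) -> float:
    """Unnormalised log posterior: log p(theta) + sum_i log p(y_i | theta)."""
    return log_prior(theta) + sum(log_lik(theta, y) for y in data)


def normalised_log_posterior(theta: float, batch: list[float], n_data: int) -> float:
    """Scale independent of batch size and N: log prior / N + mean batch log likelihood."""
    return log_prior(theta) / n_data + sum(log_lik(theta, y) for y in batch) / len(batch)


if __name__ == "__main__":
    rng = random.Random(0)
    data = [rng.gauss(1.5, 1.0) for _ in range(120)]  # synthetic data
    n, theta = len(data), 0.7
    temperature = 1.0 / n
    print(normalised_log_posterior(theta, data, n) / temperature)
    print(full_log_posterior(theta, data))  # same value up to rounding
```

On a minibatch the normalised value is an unbiased estimate of the full-batch normalised value, so the same temperature applies whatever batch size is used.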