On Uniform, Bayesian, and PAC-Bayesian Deep Ensembles

Nick Hauptvogel, Christian Igel

TL;DR

This paper argues that Bayesian model averaging (BMA) is not the most effective strategy for improving deep ensemble generalization, as it neglects interactions between ensemble members and can converge to a single model. It advocates weighting ensemble members using a second-order PAC-Bayesian bound based on the tandem loss to account for pairwise model correlations, enabling robust ensembles that can safely incorporate multiple checkpoints from the same training run. Empirical results on IMDB, CIFAR-10/100, and EyePACS show that uniformly weighted deep ensembles match or surpass Bayes ensembles, while tandem-bound weighting can achieve comparable performance with nonvacuous generalization guarantees, and snapshot ensembles benefit when weighted by the tandem bound. Overall, the work demonstrates that simple, well-weighted deep ensembles can rival sophisticated Bayes ensembles, providing theoretical guarantees and practical advantages, including circumventing the need for early stopping in some settings.

Abstract

It is common practice to combine deep neural networks into ensembles. These deep ensembles can benefit from the cancellation of errors effect: Errors by ensemble members may average out, leading to better generalization performance than each individual network. Bayesian neural networks learn a posterior distribution over model parameters, and sampling and weighting networks according to this posterior yields an ensemble model referred to as a Bayes ensemble. This study reviews the argument that neither the sampling nor the weighting in Bayes ensembles is particularly well suited for increasing generalization performance, as they do not support the cancellation of errors effect. In contrast, we show that a weighted average of models, where the weights are optimized by minimizing a second-order PAC-Bayesian generalization bound, can improve generalization. It is crucial that the optimization takes correlations between models into account. This can be achieved by minimizing the tandem loss, which requires hold-out data for estimating error correlations. The tandem-loss-based PAC-Bayesian weighting increases robustness against correlated models and models with lower performance in an ensemble. This allows us to safely add several models from the same learning process to an ensemble, instead of using early stopping for selecting a single weight configuration. Our experiments provide further evidence that state-of-the-art intricate Bayes ensembles do not outperform simple uniformly weighted deep ensembles in terms of classification accuracy. Additionally, we show that these Bayes ensembles cannot match the performance of deep ensembles weighted by optimizing the tandem loss, which also provides nonvacuous, rigorous generalization guarantees.
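
A minimal NumPy sketch of the quantities the abstract relies on, assuming 0/1 loss, hard class predictions, and a hold-out set (our illustration, not the authors' code; `tandem_loss_matrix`, `gibbs_and_tandem`, and `weighted_majority_vote` are hypothetical helper names). Following Masegosa et al. (2020), the tandem loss of two members is the rate at which both err on the same point, so $\mathbb{E}_{\rho^2}[\hat{L}(h,h')] = \rho^\top T \rho$ for the matrix $T$ of pairwise tandem losses, whose diagonal holds the individual error rates. The toy data compares three members with disjoint errors to an ensemble in which one member is duplicated, a crude stand-in for strongly correlated checkpoints.

```python
import numpy as np


def tandem_loss_matrix(preds: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Empirical tandem loss matrix on hold-out data.

    preds:  (M, n) integer class predictions of M ensemble members
    labels: (n,) true class labels
    Returns T with T[i, j] = fraction of points on which members i and j both err.
    """
    errors = (preds != labels[None, :]).astype(float)  # (M, n) 0/1 error indicators
    return errors @ errors.T / labels.shape[0]


def gibbs_and_tandem(T: np.ndarray, rho: np.ndarray) -> tuple[float, float]:
    """First-order term E_rho[L(h)] and second-order term E_{rho^2}[L(h, h')]."""
    gibbs = float(np.diag(T) @ rho)  # T[i, i] is member i's own error rate
    tandem = float(rho @ T @ rho)    # joint error of two members drawn independently from rho
    return gibbs, tandem


def weighted_majority_vote(preds: np.ndarray, rho: np.ndarray, num_classes: int) -> np.ndarray:
    """rho-weighted majority vote over the members' class predictions."""
    n = preds.shape[1]
    votes = np.zeros((num_classes, n))
    for m in range(preds.shape[0]):
        votes[preds[m], np.arange(n)] += rho[m]  # add member m's weight to its predicted class
    return votes.argmax(axis=0)


# Toy hold-out set: six points, all with label 0; a prediction of 1 is an error.
labels = np.zeros(6, dtype=int)
independent = np.array([[1, 1, 0, 0, 0, 0],   # three members that err on disjoint points
                        [0, 0, 1, 1, 0, 0],
                        [0, 0, 0, 0, 1, 1]])
correlated = np.array([[1, 1, 0, 0, 0, 0],    # member 2 is a copy of member 1
                       [1, 1, 0, 0, 0, 0],
                       [0, 0, 0, 0, 1, 1]])
uniform = np.full(3, 1 / 3)

for name, preds in [("independent", independent), ("correlated", correlated)]:
    gibbs, tandem = gibbs_and_tandem(tandem_loss_matrix(preds, labels), uniform)
    mv_error = np.mean(weighted_majority_vote(preds, uniform, 2) != labels)
    print(f"{name}: Gibbs={gibbs:.3f} tandem={tandem:.3f} MV error={mv_error:.3f}")
```

Both ensembles have the same first-order (Gibbs) loss of 1/3, so a criterion built on individual errors alone cannot tell them apart. The duplicated member raises the uniform tandem loss from 1/9 to 5/27 and the majority-vote error from 0 to 1/3, which is the correlation effect the tandem loss is meant to capture.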

Paper Structure (27 sections, 3 theorems, 18 equations, 7 figures, 4 tables)

Key Result

Theorem 1

For any probability distribution $\pi$ on $\mathcal{H}$ that is independent of $\mathcal{D}$ and any $\delta\in ]0,1[$, with probability at least $1-\delta$ over a random draw of $\mathcal{D}$ with $n$ elements, for all distributions $\rho$ on $\mathcal{H}$ and all $\lambda \in ]0, 2[$ simultaneously: …
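
Theorem 1 is cited to masegosa2020second and allows $\lambda \in ]0,2[$, which matches the tandem (TND) bound of Masegosa et al. (2020). Assuming that is the statement, its conclusion for the $\rho$-weighted majority vote $\mathrm{MV}_\rho$ reads

$$L(\mathrm{MV}_\rho) \le 4\left(\frac{\mathbb{E}_{\rho^2}[\hat{L}(h,h',\mathcal{D})]}{1-\lambda/2} + \frac{2\,\mathrm{KL}(\rho\|\pi) + \ln(2\sqrt{n}/\delta)}{\lambda(1-\lambda/2)\,n}\right).$$

For fixed $\rho$, setting the derivative in $\lambda$ to zero gives $\lambda^* = 2/\left(\sqrt{2n\,\mathbb{E}_{\rho^2}[\hat{L}]/c + 1} + 1\right)$ with $c = 2\,\mathrm{KL}(\rho\|\pi) + \ln(2\sqrt{n}/\delta)$. The sketch below evaluates the bound and minimizes it by alternating softmax-parameterized gradient steps on $\rho$ with this closed-form $\lambda$, accepting only steps that lower the bound. The optimizer, the helper names (`tnd_bound`, `optimal_lambda`, `minimize_tnd`), $\delta = 0.05$, the uniform prior, and the synthetic error rates are assumptions for illustration; the paper's own optimization procedure may differ.

```python
import numpy as np


def kl(rho: np.ndarray, pi: np.ndarray) -> float:
    """KL(rho || pi) for discrete distributions, with 0 * log 0 taken as 0."""
    mask = rho > 0
    return float(np.sum(rho[mask] * np.log(rho[mask] / pi[mask])))


def complexity(rho, pi, n, delta):
    """Numerator of the complexity term: 2 KL(rho || pi) + ln(2 sqrt(n) / delta)."""
    return 2 * kl(rho, pi) + np.log(2 * np.sqrt(n) / delta)


def tnd_bound(T, rho, pi, n, delta, lam):
    """Assumed TND form: 4 * (E_{rho^2}[L_hat] / (1 - lam/2) + c / (lam (1 - lam/2) n))."""
    tandem = rho @ T @ rho
    return 4 * (tandem / (1 - lam / 2)
                + complexity(rho, pi, n, delta) / (lam * (1 - lam / 2) * n))


def optimal_lambda(T, rho, pi, n, delta):
    """Closed-form minimizer of tnd_bound over lam in ]0, 2[ for fixed rho."""
    tandem = rho @ T @ rho
    return 2 / (np.sqrt(2 * n * tandem / complexity(rho, pi, n, delta) + 1) + 1)


def minimize_tnd(T, pi, n, delta=0.05, steps=500, lr=1.0):
    """Alternate gradient steps on the softmax logits of rho with the closed-form lam."""
    theta = np.log(pi)                      # softmax(log pi) = pi, so start at the prior
    rho = pi.copy()
    lam = optimal_lambda(T, rho, pi, n, delta)
    best = tnd_bound(T, rho, pi, n, delta, lam)
    for _ in range(steps):
        # Gradient of the bracketed objective w.r.t. rho for fixed lam (factor 4 dropped).
        log_ratio = np.log(np.maximum(rho, 1e-300) / pi)
        g = 2 * (T @ rho) / (1 - lam / 2) + 2 * (log_ratio + 1) / (lam * (1 - lam / 2) * n)
        grad_theta = rho * (g - rho @ g)    # chain rule through the softmax
        cand_theta = theta - lr * grad_theta
        cand = np.exp(cand_theta - cand_theta.max())
        cand /= cand.sum()
        cand_lam = optimal_lambda(T, cand, pi, n, delta)
        cand_bound = tnd_bound(T, cand, pi, n, delta, cand_lam)
        if cand_bound < best:               # keep only steps that lower the bound
            theta, rho, lam, best = cand_theta, cand, cand_lam, cand_bound
        else:
            lr /= 2                         # otherwise shrink the step size
    return rho, lam, best


# Synthetic 0/1 errors of M = 5 independent members on n hold-out points (hypothetical rates).
rng = np.random.default_rng(0)
n, M = 1000, 5
rates = np.array([[0.10], [0.12], [0.15], [0.30], [0.30]])
errors = (rng.random((M, n)) < rates).astype(float)
T = errors @ errors.T / n                   # empirical tandem loss matrix
pi = np.full(M, 1 / M)                      # uniform prior, chosen independently of the data
rho, lam, bound = minimize_tnd(T, pi, n)
print("rho =", np.round(rho, 3), "lambda =", round(float(lam), 3), "bound =", round(float(bound), 3))
```

Because only improving steps are accepted, the returned bound is never above its value at $\rho = \pi$. Weight should move toward the members with lower hold-out error, while the KL term limits how far $\rho$ can drift from the prior.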

Figures (7)

  • Figure 1: Mean test accuracy $\hat{A} \pm \sigma$ vs. ensemble size $M$ over five ensembles for uniformly and PAC-Bayesian weighted deep ensembles ($\mathrm{AVG}_u$ and $\mathrm{AVG}_\rho$), using either only the last or all training checkpoints, and the best single model (SGD). References are the Bayesian ensembles cSGHMC-ap wenzel_2020_how, cSGLD ashukha2021pitfalls, MC-Dropout, and MC-Dropout Deep Ensemble (i.e., an ensemble of Bayesian ensembles), as well as simple Deep Ensembles for EyePACS from band2021benchmarking. Numbers in brackets indicate the ensemble sizes of the baselines. Additional results for other settings are shown in an appendix figure (fig:original_acc).
  • Figure 2: IMDB accuracies ($\hat{A}$, top) and weight distribution per member for first-order (middle, lorenzen2019pac) and tandem bound weighting (bottom) for the Simple setting considering only a single network per training process (left) and when adding checkpoints (right). The number of training runs was 40. Four intermediate checkpoints were added, giving a total of five weight configurations per training process.
  • Figure 3: ResNet20 on CIFAR-10 accuracies (top) and weight distribution per member for first-order (middle) and tandem bound weighting (bottom) for the Simple setting (left) and SSE (right), both including all checkpoints (24 training processes, five checkpoints per process).
  • Figure A.4: Mean test accuracy $\hat{A} \pm \sigma$ over five ensembles for uniformly and PAC-Bayesian weighted deep ensembles ($\mathrm{AVG}_u$ and $\mathrm{AVG}_\rho$), using either only the last or all training checkpoints, and the best single model (SGD). References are the Bayesian ensembles cSGHMC-ap wenzel_2020_how, cSGLD ashukha2021pitfalls, MC-Dropout, and MC-Dropout Deep Ensemble (i.e., an ensemble of Bayesian ensembles), as well as simple Deep Ensembles for EyePACS from band2021benchmarking. Numbers in brackets indicate the ensemble sizes.
  • Figure A.5: Snapshot ensemble (SSE) experiments: Mean ensemble accuracy (5 ensembles, $\pm\sigma$)
  • ...and 2 more figures

Theorems & Definitions (7)

  • Theorem 1: masegosa2020second
  • Theorem 2: Bernstein-von Mises theorem
  • Claim 1
  • Claim 2
  • Claim 3
  • Theorem 3
  • Proof
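
The Bernstein-von Mises theorem listed above concerns the concentration of the posterior as data accumulates, which we assume underpins the point that BMA can converge to a single model. A discrete analogue (our toy, not an experiment from the paper) shows the effect for BMA over a finite set of trained networks: with i.i.d. data, each model's posterior weight is proportional to its prior times its likelihood, so small differences in average log-likelihood are amplified exponentially in $n$. The helper `bma_weights` and the log-likelihood values are hypothetical.

```python
import numpy as np


def bma_weights(log_liks: np.ndarray, log_prior: np.ndarray) -> np.ndarray:
    """Posterior weights over a finite set of models, w_m ∝ prior_m * prod_i p_m(y_i | x_i).

    log_liks:  (M, n) per-example log-likelihoods log p_m(y_i | x_i), data assumed i.i.d.
    log_prior: (M,) log prior probabilities of the M models
    """
    log_post = log_prior + log_liks.sum(axis=1)   # unnormalized log posterior
    log_post = log_post - log_post.max()          # shift for numerical stability
    w = np.exp(log_post)
    return w / w.sum()


# Three hypothetical models whose average log-likelihood per example differs only slightly.
avg_loglik = np.array([-0.30, -0.31, -0.35])
log_prior = np.log(np.full(3, 1 / 3))
for n in (10, 100, 1000):
    log_liks = np.repeat(avg_loglik[:, None], n, axis=1)  # same value on every example (toy)
    print(n, np.round(bma_weights(log_liks, log_prior), 4))
```

As $n$ grows, the weights move from fairly even (largest weight about 0.40 at $n = 10$) to almost all mass on the first model at $n = 1000$, even though the other two models are nearly as good and could err on different examples. Such a collapsed average no longer benefits from the cancellation of errors effect.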