
BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning

Yeming Wen, Dustin Tran, Jimmy Ba

TL;DR

The paper tackles the high computational and memory costs of traditional ensembles by proposing BatchEnsemble, a parameter-efficient ensemble that reuses a shared weight and injects per-member rank-1 fast weights via a Hadamard product. The method enables both intra- and inter-device parallelism, achieving substantial test-time speedups and memory reductions (about 3X each at an ensemble size of 4) while maintaining competitive accuracy and calibrated uncertainty across vision and language tasks. It extends naturally to lifelong learning, scaling to 100 sequential tasks without forgetting and without maintaining large task-specific networks. Extensive experiments on CIFAR, WMT translation, and calibration benchmarks demonstrate that BatchEnsemble offers the best accuracy–diversity–efficiency trade-off among efficient ensembles and scales to large, sequential learning settings.
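To make the rank-1 construction concrete, here is a minimal pure-Python sketch for a single dense layer. It assumes the convention W_i = W ∘ (r_i s_iᵀ), with W of shape m×n, r_i (length m) scaling the inputs and s_i (length n) scaling the outputs; the helper names are hypothetical and not taken from the paper's code. The sketch checks that multiplying by the explicit member weight gives the same pre-activation as the cheaper factored form (Wᵀ(x ∘ r_i)) ∘ s_i, which never materializes W_i.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def member_weight(W: Matrix, r: Vector, s: Vector) -> Matrix:
    """Explicit member weight W_i = W ∘ (r s^T), shape m x n."""
    return [[W[a][b] * r[a] * s[b] for b in range(len(s))] for a in range(len(r))]


def matvec_t(W: Matrix, x: Vector) -> Vector:
    """Compute W^T x for W of shape m x n and x of length m."""
    m, n = len(W), len(W[0])
    return [sum(W[a][b] * x[a] for a in range(m)) for b in range(n)]


def batch_ensemble_dense(W: Matrix, r: Vector, s: Vector, x: Vector) -> Vector:
    """Pre-activation (W^T (x ∘ r)) ∘ s: only the shared W is multiplied."""
    scaled_in = [xa * ra for xa, ra in zip(x, r)]  # x ∘ r
    return [yb * sb for yb, sb in zip(matvec_t(W, scaled_in), s)]  # (...) ∘ s


# Shared slow weight (m=3 inputs, n=2 outputs) and one member's fast weights.
W = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
r = [1.0, 2.0, -1.0]
s = [0.5, 2.0]
x = [1.0, 1.0, 2.0]

explicit = matvec_t(member_weight(W, r, s), x)
factored = batch_ensemble_dense(W, r, s, x)
print(explicit, factored)  # both print [-2.0, 0.0]
```

The factored form only needs the shared W plus m + n extra numbers per member, which is where the memory savings come from.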

Abstract

Ensembles, where multiple neural networks are trained individually and their predictions are averaged, have been shown to be widely successful for improving both the accuracy and predictive uncertainty of single neural networks. However, an ensemble's cost for both training and testing increases linearly with the number of networks, which quickly becomes untenable. In this paper, we propose BatchEnsemble, an ensemble method whose computational and memory costs are significantly lower than those of typical ensembles. BatchEnsemble achieves this by defining each weight matrix to be the Hadamard product of a weight shared among all ensemble members and a rank-one matrix per member. Unlike ensembles, BatchEnsemble is not only parallelizable across devices, where one device trains one member, but also parallelizable within a device, where multiple ensemble members are updated simultaneously for a given mini-batch. Across CIFAR-10, CIFAR-100, WMT14 EN-DE/EN-FR translation, and out-of-distribution tasks, BatchEnsemble yields accuracy and uncertainty estimates competitive with typical ensembles; at an ensemble of size 4, it achieves a 3X speedup at test time and a 3X reduction in memory. We also apply BatchEnsemble to lifelong learning, where on Split-CIFAR-100, BatchEnsemble yields performance comparable to progressive neural networks while having much lower computational and memory costs. We further show that BatchEnsemble can easily scale up to lifelong learning on Split-ImageNet, which involves 100 sequential learning tasks.
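Within-device parallelism follows from the fact that every member shares the same W: if each example in a mini-batch carries the fast weights of the member it is routed to, all members can be evaluated with one product against W. The pure-Python sketch below assumes the same convention as above (r_i scales inputs, s_i scales outputs) and computes ((X ∘ R) W) ∘ S, where row k of R and S holds the fast weights for example k. For prediction it repeats one input once per member and averages the softmax probabilities; all function names are hypothetical illustrations, not the authors' implementation.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def hadamard(A: Matrix, B: Matrix) -> Matrix:
    """Element-wise product of two equally shaped matrices."""
    return [[a * b for a, b in zip(row_a, row_b)] for row_a, row_b in zip(A, B)]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    """Plain matrix product A @ B."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]


def batched_forward(X: Matrix, W: Matrix, R: Matrix, S: Matrix) -> Matrix:
    """Pre-activations ((X ∘ R) W) ∘ S for a whole mini-batch.

    Row k of R and S holds the fast weights of the member that example k is
    routed to, so all members share a single product with W.
    """
    return hadamard(matmul(hadamard(X, R), W), S)


def softmax(z: Vector) -> Vector:
    """Numerically stable softmax of one logit vector."""
    top = max(z)
    exps = [math.exp(v - top) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(x: Vector, W: Matrix, rs: Matrix, ss: Matrix) -> Vector:
    """Repeat x once per member, run one batched pass, average probabilities."""
    num_members = len(rs)
    X = [list(x) for _ in range(num_members)]  # row i is routed to member i
    logits = batched_forward(X, W, rs, ss)
    probs = [softmax(row) for row in logits]
    return [sum(p[c] for p in probs) / num_members for c in range(len(probs[0]))]


# Shared weight (3 inputs -> 2 classes) and fast weights for 2 members.
W = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
rs = [[1.0, 2.0, -1.0], [1.0, 1.0, 1.0]]  # r_i for each member, one per row
ss = [[0.5, 2.0], [1.0, 1.0]]             # s_i for each member, one per row
x = [1.0, 1.0, 2.0]

print(batched_forward([x, x], W, rs, ss))  # -> [[-2.0, 0.0], [7.5, 1.0]]
print(ensemble_predict(x, W, rs, ss))      # averaged class probabilities
```

The same `batched_forward` call also covers training-time routing: filling the rows of R and S with different members' fast weights lets one mini-batch update several members at once. How the paper assigns examples to members during training may differ in detail from this sketch.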

Paper Structure

This paper contains 25 sections, 6 equations, 9 figures, 6 tables.

Figures (9)

  • Figure 1: Test-time cost (blue) and memory cost (orange) of BatchEnsemble w.r.t. the ensemble size, relative to the cost of a single model. The test-time and memory costs of a naive ensemble are plotted in green.
  • Figure 2: An illustration of how to generate the ensemble weights for two ensemble members.
  • Figure 3: Performance for lifelong learning. (a): Validation accuracy for each Split-ImageNet task. Standard deviation is computed over 5 random seeds. (b): BatchEnsemble and several other methods on Split-CIFAR100. BatchEnsemble achieves the best trade-off among Accuracy ($\uparrow$), Forget ($\downarrow$), and Time & Memory ($\downarrow$) costs. VAN: Vanilla neural network. EWC: Elastic weight consolidation Kirkpatrick2016OvercomingCF. PNN: Progressive neural network Rusu2016ProgressiveNN. BN-Tuned: Fine-tuning the BatchNorm layers for each subsequent task. BatchE: BatchEnsemble. Upperbound: An individual ResNet-50 per task.
  • Figure 4: Comparison between BatchEnsemble and a single model on WMT English-German and English-French. Training stops once the model reaches the target validation perplexity. BatchEnsemble converges faster by taking advantage of multiple models. (a): Validation loss on the WMT16 English-German task. (b): Validation loss on the WMT14 English-French task. Big: Transformer big model. Base: Transformer base model. BE: BatchEnsemble. Single: Single model.
  • Figure 5: Calibration on CIFAR-10 corruptions: boxplots showing a comparison of ECE under all types of corruptions on CIFAR-10. Each box shows the quartiles summarizing the results across all types of skew while the error bars indicate the min and max across different skew types. Ensemble/BatchEnsemble: Naive/Batch ensemble of 4 ResNet32x4 models. Dropout-8: Dropout ensemble with sample size 8. BEDrop-8: BatchEnsemble of 4 models + Dropout ensemble with sample size 8. A similar measurement can be found in Ovadia2019CanYT.
  • ...and 4 more figures
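Figure 1 plots measured costs, and the abstract reports roughly a 3X memory reduction at an ensemble size of 4. A weight-only count for one dense layer gives some intuition for why memory grows so slowly with ensemble size. The sketch below is idealized: it ignores biases, normalization layers, and activations, so it need not match the measured numbers, and the 512×512 layer size is a hypothetical example. The lifelong-learning count additionally assumes that the shared weight is reused unchanged across tasks and that each task adds only its own r and s.

```python
from typing import Tuple


def dense_weight_counts(m: int, n: int, num_members: int) -> Tuple[int, int]:
    """Weight counts for one m x n dense layer.

    Naive ensemble: num_members independent m x n matrices.
    BatchEnsemble: one shared m x n matrix plus r_i (length m) and s_i
    (length n) per member. Biases, normalization layers and activations
    are deliberately ignored.
    """
    naive = num_members * m * n
    batch_ensemble = m * n + num_members * (m + n)
    return naive, batch_ensemble


def lifelong_weight_count(m: int, n: int, num_tasks: int) -> int:
    """Shared weight plus one rank-1 pair of fast weights per task (assumed)."""
    return m * n + num_tasks * (m + n)


# Hypothetical 512 x 512 layer with an ensemble of size 4.
naive, be = dense_weight_counts(512, 512, 4)
print(naive, be, f"{naive / be:.2f}")  # -> 1048576 266240 3.94

# Same layer reused across 100 sequential tasks vs. 100 separate layers.
print(lifelong_weight_count(512, 512, 100), 100 * 512 * 512)  # -> 364544 26214400
```

Under these assumptions, each extra member or task costs only m + n weights per layer rather than m × n, which is consistent with the small per-task overhead described for Split-ImageNet.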