BatchEnsemble: An Alternative Approach to Efficient Ensemble and Lifelong Learning
Yeming Wen, Dustin Tran, Jimmy Ba
TL;DR
The paper tackles the high computational and memory costs of traditional ensembles by proposing BatchEnsemble, a parameter-efficient ensemble that reuses a shared weight and injects per-member rank-1 fast weights via a Hadamard product. The method enables both intra- and inter-device parallelism, achieving substantial test-time speedups and memory reductions (3X each at an ensemble size of 4) while maintaining competitive accuracy and calibrated uncertainty across vision and language tasks. It extends naturally to lifelong learning, handling up to 100 sequential tasks with no forgetting and without maintaining large task-specific networks. Extensive experiments on CIFAR, WMT translation, and calibration benchmarks demonstrate that BatchEnsemble offers the best accuracy–diversity–efficiency trade-off among efficient ensembles and scales to large, sequential learning settings.
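The shared-weight-plus-rank-1 construction can be made concrete with a small sketch. It assumes NumPy and a single dense layer whose shared weight W has shape (m, n), with per-member fast weights r_i of length m and s_i of length n; the helper names `batchensemble_weights` and `param_counts` and the layer sizes are hypothetical, chosen only for illustration. Each member's weight is W multiplied elementwise by the outer product r_i s_i^T. The parameter count is plain arithmetic for one layer with biases ignored; it is not the paper's measured memory reduction.

```python
import numpy as np

def batchensemble_weights(W, R, S):
    """Hypothetical helper: member weights W_i = W * outer(r_i, s_i).

    W: shared weight, shape (m, n)
    R: fast weights r_i stacked as rows, shape (k, m)
    S: fast weights s_i stacked as rows, shape (k, n)
    Returns an array of shape (k, m, n), one full weight per member.
    """
    # Broadcasting: (1, m, n) * (k, m, 1) * (k, 1, n) -> (k, m, n)
    return W[None, :, :] * R[:, :, None] * S[:, None, :]

def param_counts(m, n, k):
    """Hypothetical helper: parameters for one dense layer (biases ignored)."""
    full_ensemble = k * m * n             # k independent m x n matrices
    batch_ensemble = m * n + k * (m + n)  # one shared matrix + k rank-1 factors
    return full_ensemble, batch_ensemble

rng = np.random.default_rng(0)
m, n, k = 6, 4, 4
W = rng.normal(size=(m, n))
R = rng.normal(size=(k, m))
S = rng.normal(size=(k, n))

members = batchensemble_weights(W, R, S)
print(members.shape)  # (4, 6, 4)
# Member 0 is the shared weight masked elementwise by the rank-1 matrix r_0 s_0^T
print(np.allclose(members[0], W * np.outer(R[0], S[0])))  # True

print(param_counts(512, 512, 4))  # (1048576, 266240)
```

For a 512 x 512 layer with 4 members, the extra cost of the ensemble is only 4 x (512 + 512) = 4,096 parameters on top of the shared matrix, which is why the per-member overhead stays small as the ensemble grows.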
Abstract
Ensembles, where multiple neural networks are trained individually and their predictions are averaged, have been widely successful at improving both the accuracy and predictive uncertainty of single neural networks. However, an ensemble's cost for both training and testing increases linearly with the number of networks, which quickly becomes untenable. In this paper, we propose BatchEnsemble, an ensemble method whose computational and memory costs are significantly lower than those of typical ensembles. BatchEnsemble achieves this by defining each weight matrix to be the Hadamard product of a weight shared among all ensemble members and a rank-one matrix per member. Unlike ensembles, BatchEnsemble is not only parallelizable across devices, where one device trains one member, but also parallelizable within a device, where multiple ensemble members are updated simultaneously for a given mini-batch. Across CIFAR-10, CIFAR-100, WMT14 EN-DE/EN-FR translation, and out-of-distribution tasks, BatchEnsemble yields accuracy and uncertainty estimates competitive with those of typical ensembles, with a 3X speedup at test time and a 3X memory reduction for an ensemble of size 4. We also apply BatchEnsemble to lifelong learning, where on Split-CIFAR-100 BatchEnsemble yields performance comparable to progressive neural networks while having much lower computational and memory costs. We further show that BatchEnsemble can easily scale up to lifelong learning on Split-ImageNet, which involves 100 sequential learning tasks.
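Within-device parallelism follows from a simple identity: for one member, (W ∘ r_i s_i^T)^T x = (W^T (x ∘ r_i)) ∘ s_i, so every member can share a single matrix multiply with W once its inputs and outputs are scaled by that member's fast weights. The sketch below assumes NumPy, a dense layer with no bias or nonlinearity, and a mini-batch whose rows are assigned to members round-robin; the function names and the round-robin assignment are hypothetical illustrations, not the paper's code. It checks the vectorized form against a naive loop that materializes each member's full weight.

```python
import numpy as np

def batchensemble_dense(X, W, R, S, member_ids):
    """Hypothetical vectorized BatchEnsemble dense layer (no bias, no activation).

    X: inputs, shape (B, m); member_ids: shape (B,), member index per row.
    W: shared weight (m, n); R: fast weights (k, m); S: fast weights (k, n).
    """
    r = R[member_ids]  # (B, m): gather r_i for each example's member
    s = S[member_ids]  # (B, n): gather s_i for each example's member
    # (W * r s^T)^T x == (W^T (x * r)) * s, so one shared matmul serves all members
    return ((X * r) @ W) * s

def naive_dense(X, W, R, S, member_ids):
    """Reference: build each member's full weight and apply it row by row."""
    out = []
    for x, i in zip(X, member_ids):
        W_i = W * np.outer(R[i], S[i])
        out.append(x @ W_i)  # equals W_i^T x for a single example
    return np.stack(out)

rng = np.random.default_rng(1)
B, m, n, k = 8, 5, 3, 4
X = rng.normal(size=(B, m))
W = rng.normal(size=(m, n))
R = rng.normal(size=(k, m))
S = rng.normal(size=(k, n))
# Assumed assignment: rows 0..3 go to members 0..3, rows 4..7 repeat the pattern
member_ids = np.arange(B) % k

fast = batchensemble_dense(X, W, R, S, member_ids)
slow = naive_dense(X, W, R, S, member_ids)
print(fast.shape, np.allclose(fast, slow))  # (8, 3) True
```

The vectorized path never forms the k full member matrices; it only adds two elementwise products around the shared matmul, which is the property that lets several members be updated on one device for the same mini-batch.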
