WASH: Train your Ensemble with Communication-Efficient Weight Shuffling, then Average
Louis Fournier, Adel Nabli, Masih Aminbeidokhti, Marco Pedersoli, Eugene Belilovsky, Edouard Oyallon
TL;DR
This work tackles the trade-off between accuracy and inference cost in model ensembling by introducing WASH, a communication-efficient training scheme that makes models trained in parallel amenable to weight averaging through parameter shuffling. By randomly permuting a small fraction of parameters across models, with a layer-aware adaptation, WASH preserves diversity while keeping the models near a consensus, so the averaged model performs well at far lower communication cost than prior parameter-averaging methods such as PAPA. Empirically, WASH achieves state-of-the-art results on image classification tasks: the averaged model approaches ensemble performance at the inference cost of a single model. Ablations highlight the importance of shuffling early layers and of using modest shuffling probabilities. The approach is relevant to scalable, resource-efficient deployment of ensemble-like models, and code is released for reproducibility.
Abstract
The performance of deep neural networks is enhanced by ensemble methods, which average the outputs of several models. However, this increases the cost of inference. Weight averaging methods aim to balance the generalization of ensembling with the inference speed of a single model by averaging the parameters of an ensemble of models. Yet naive averaging results in poor performance because the models converge to different loss basins, and aligning the models to improve the performance of the average is challenging. Alternatively, inspired by distributed training, methods like DART and PAPA have been proposed to train several models in parallel so that they end up in the same basin, resulting in good averaging accuracy. However, these methods either compromise ensembling accuracy or demand significant communication between models during training. In this paper, we introduce WASH, a novel distributed method for training model ensembles for weight averaging that achieves state-of-the-art image classification accuracy. WASH keeps the models within the same basin by randomly shuffling a small percentage of weights during training, which yields diverse models and lower communication costs than standard parameter averaging methods.
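As an illustration of the shuffling idea described above, here is a minimal sketch in plain Python. It is not the authors' implementation: the function names `shuffle_weights` and `average_weights`, the flat-list representation of parameters, and the single shuffling probability `p` are all assumptions made for this example, and the layer-aware adaptation is not modeled. Each parameter coordinate is, with probability `p`, permuted across the models; the other coordinates are left untouched.

```python
import random


def shuffle_weights(models, p, rng):
    """Permute each parameter coordinate across models with probability p.

    models: list of equal-length lists of floats, one flat parameter
            vector per model (a hypothetical simplification).
    p:      probability that a given coordinate is shuffled (assumed
            uniform here; no layer-aware adaptation).
    Returns new vectors; the inputs are not modified.
    """
    n_models = len(models)
    shuffled = [list(m) for m in models]
    for i in range(len(models[0])):
        if rng.random() < p:
            # Draw a random permutation of the model indices and move
            # coordinate i of model `src` into model `dst`.
            perm = list(range(n_models))
            rng.shuffle(perm)
            for dst, src in enumerate(perm):
                shuffled[dst][i] = models[src][i]
    return shuffled


def average_weights(models):
    """Coordinate-wise mean of the parameter vectors."""
    n = len(models)
    return [sum(values) / n for values in zip(*models)]


rng = random.Random(0)
# Four toy "models" with eight parameters each.
models = [[float(10 * k + i) for i in range(8)] for k in range(4)]
shuffled = shuffle_weights(models, p=0.25, rng=rng)
```

Because a shuffle only permutes values across models at a given coordinate, it leaves the averaged weights unchanged at the moment it is applied. In a distributed setting only the shuffled coordinates would need to be exchanged, which is where the communication saving over averaging all parameters would come from under these assumptions.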
