Training independent subnetworks for robust prediction
Marton Havasi, Rodolphe Jenatton, Stanislav Fort, Jeremiah Zhe Liu, Jasper Snoek, Balaji Lakshminarayanan, Andrew M. Dai, Dustin Tran
TL;DR
Deep ensembles improve uncertainty estimation and robustness to distribution shift, but their training and inference costs are high. The authors introduce MIMO, a multi-input multi-output architecture that trains $M$ independent subnetworks inside a single network: the network takes $M$ concatenated inputs and produces $M$ outputs, one per subnetwork. At test time the same input ${\mathbf{x}}'$ is tiled $M$ times and the $M$ predictions are averaged: $p_\theta({\mathbf{y}}'|{\mathbf{x}}')=\frac{1}{M}\sum_{m=1}^M p_\theta({\mathbf{y}}_m|{\mathbf{x}}',\dots,{\mathbf{x}}')$. The subnetworks converge to distinct optima and exhibit diversity comparable to independently trained ensembles, enabling robust predictions with only a negligible increase in parameters and compute. Empirical results on CIFAR-10, CIFAR-100, ImageNet, and their out-of-distribution (OOD) variants show improved negative log-likelihood, accuracy, and calibration, approaching deep ensembles while requiring only a single forward pass. The work also gives practical guidance on choosing the number of subnetworks and on using input or batch repetition to improve performance, making ensemble-like prediction more scalable in practice.
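The test-time rule above (tile the input $M$ times, run one forward pass, average the $M$ head outputs) can be sketched in a few lines of plain Python. This is an illustrative toy, not the authors' implementation: the linear model stands in for a deep network, the weights are random rather than trained, and the names `make_toy_mimo`, `mimo_forward`, and `mimo_predict` are hypothetical.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def make_toy_mimo(input_dim, num_classes, num_subnetworks, seed=0):
    """Random weights for a toy linear MIMO model (assumed stand-in for a deep net).

    The weight matrix maps the concatenation of M inputs (M * input_dim values)
    to M * num_classes logits, one block of logits per output head.
    """
    rng = random.Random(seed)
    rows = num_subnetworks * num_classes
    cols = num_subnetworks * input_dim
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def mimo_forward(weights, inputs, num_classes):
    """One forward pass: concatenate M inputs, return M class distributions."""
    concatenated = [v for x in inputs for v in x]
    logits = [sum(w * v for w, v in zip(row, concatenated)) for row in weights]
    m = len(inputs)
    return [softmax(logits[i * num_classes:(i + 1) * num_classes]) for i in range(m)]


def mimo_predict(weights, x, num_subnetworks, num_classes):
    """Test-time prediction: tile x M times and average the M head outputs."""
    heads = mimo_forward(weights, [x] * num_subnetworks, num_classes)
    return [sum(h[c] for h in heads) / num_subnetworks for c in range(num_classes)]


# Usage: M = 2 subnetworks, 4 input features, 3 classes.
weights = make_toy_mimo(input_dim=4, num_classes=3, num_subnetworks=2)
probs = mimo_predict(weights, [0.5, -1.0, 2.0, 0.1], num_subnetworks=2, num_classes=3)
```

Because all $M$ heads are read from the same forward pass, the averaging step adds only a handful of arithmetic operations on top of a single model evaluation, which is the source of the single-pass efficiency described above.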
Abstract
Recent approaches to efficiently ensemble neural networks have shown that strong robustness and uncertainty performance can be achieved with a negligible gain in parameters over the original network. However, these methods still require multiple forward passes for prediction, leading to a significant computational cost. In this work, we show a surprising result: the benefits of using multiple predictions can be achieved 'for free' under a single model's forward pass. In particular, we show that, using a multi-input multi-output (MIMO) configuration, one can utilize a single model's capacity to train multiple subnetworks that independently learn the task at hand. By ensembling the predictions made by the subnetworks, we improve model robustness without increasing compute. We observe a significant improvement in negative log-likelihood, accuracy, and calibration error on CIFAR-10, CIFAR-100, ImageNet, and their out-of-distribution variants compared to previous methods.
