
Training independent subnetworks for robust prediction

Marton Havasi, Rodolphe Jenatton, Stanislav Fort, Jeremiah Zhe Liu, Jasper Snoek, Balaji Lakshminarayanan, Andrew M. Dai, Dustin Tran

TL;DR

Ensembles deliver strong uncertainty estimates and robustness to distribution shift, but at a high computational cost. The authors introduce MIMO, a multi-input multi-output architecture that trains $M$ independent subnetworks inside a single network. During training, $M$ independently sampled inputs are concatenated and the network produces $M$ outputs, one per input. At test time, the same input is tiled $M$ times and the $M$ predictions are averaged: $p_\theta({\mathbf{y}}'|{\mathbf{x}}')=\frac{1}{M}\sum_{m=1}^M p_\theta({\mathbf{y}}_m|{\mathbf{x}}',\dots,{\mathbf{x}}')$. The subnetworks converge to distinct optima and exhibit diversity comparable to independently trained ensembles, enabling robust predictions with only a negligible increase in parameters and compute. Empirical results on CIFAR-10, CIFAR-100, ImageNet, and their OOD variants show improved negative log-likelihood, accuracy, and calibration, approaching deep ensembles while keeping single-forward-pass efficiency. The work also gives practical guidance on choosing the number of subnetworks and on using input or batch repetition to maximize performance, making robust, ensemble-like prediction more scalable in practice.
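A minimal NumPy sketch of the two procedures described above, assuming a toy linear model with $M$ output heads in place of the shared ResNet trunk used in the paper. `make_training_batch`, `mimo_predict`, and the parameters `rho` (input repetition probability) and `batch_repetitions` are hypothetical names chosen for illustration, and the exact repetition scheme in the paper may differ. Training would minimize the sum of the $M$ per-head cross-entropy losses on `y_mimo`; only the batch construction and the single-pass test-time average are shown.

```python
import numpy as np

def make_training_batch(x, y, M, rng, rho=0.0, batch_repetitions=1):
    """Build one hypothetical MIMO training batch.

    x: (B, D) inputs, y: (B,) labels. Each of the M members sees an
    independently shuffled copy of the batch. rho and batch_repetitions
    are illustrative knobs for input and batch repetition (assumptions).
    """
    # Batch repetition: tile the batch so each example appears r times.
    x = np.concatenate([x] * batch_repetitions, axis=0)
    y = np.concatenate([y] * batch_repetitions, axis=0)
    n = x.shape[0]
    # One independent permutation per subnetwork.
    perms = [rng.permutation(n) for _ in range(M)]
    # Input repetition: with probability rho, row i feeds the same example to all members.
    shared = rng.random(n) < rho
    for m in range(1, M):
        perms[m] = np.where(shared, perms[0], perms[m])
    idx = np.stack(perms, axis=1)        # (n, M)
    x_mimo = x[idx].reshape(n, -1)       # concatenate M inputs: (n, M*D)
    y_mimo = y[idx]                      # one target per output head: (n, M)
    return x_mimo, y_mimo

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mimo_predict(W, b, x, M):
    """Ensemble prediction from a single forward pass of a toy linear MIMO model.

    W: (M*D, M*K), b: (M*K,); head m owns logits [m*K:(m+1)*K].
    """
    n = x.shape[0]
    x_tiled = np.tile(x, (1, M))         # repeat the same input M times: (n, M*D)
    logits = x_tiled @ W + b             # one forward pass for all heads
    K = logits.shape[1] // M
    probs = softmax(logits.reshape(n, M, K))  # p(y_m | x', ..., x') per head
    return probs.mean(axis=1)            # average the M predictions

rng = np.random.default_rng(0)
B, D, K, M = 8, 5, 3, 3
x = rng.normal(size=(B, D))
y = rng.integers(0, K, size=B)
xb, yb = make_training_batch(x, y, M, rng, rho=0.5, batch_repetitions=2)
print(xb.shape, yb.shape)  # (16, 15) (16, 3)
W = rng.normal(size=(M * D, M * K)) * 0.1
b = np.zeros(M * K)
p = mimo_predict(W, b, x, M)
print(p.shape, np.allclose(p.sum(axis=1), 1.0))  # (8, 3) True
```

Tiling at test time uses the same feature ordering as the concatenation in training, so head $m$ always reads from the $m$-th input slot; with `rho=1.0` every member sees the same example, which removes the independence between subnetworks that MIMO relies on.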

Abstract

Recent approaches to efficiently ensemble neural networks have shown that strong robustness and uncertainty performance can be achieved with a negligible increase in parameters over the original network. However, these methods still require multiple forward passes for prediction, leading to a significant computational cost. In this work, we show a surprising result: the benefits of using multiple predictions can be achieved 'for free' under a single model's forward pass. In particular, we show that, using a multi-input multi-output (MIMO) configuration, one can utilize a single model's capacity to train multiple subnetworks that independently learn the task at hand. By ensembling the predictions made by the subnetworks, we improve model robustness without increasing compute. We observe a significant improvement in negative log-likelihood, accuracy, and calibration error on CIFAR10, CIFAR100, ImageNet, and their out-of-distribution variants compared to previous methods.

Paper Structure

This paper contains 16 sections, 4 equations, 9 figures, 4 tables, 2 algorithms.

Figures (9)

  • Figure 1: In the multi-input multi-output (MIMO) configuration, the network takes $M=3$ inputs and gives $M$ outputs. The hidden layers remain unchanged. The black connections are shared by all subnetworks, while the colored connections are for individual subnetworks. (a) During training, the inputs are independently sampled from the training set and the outputs are trained to classify their corresponding inputs. (b) During testing, the same input is repeated $M$ times and the outputs are averaged in an ensemble to obtain the final prediction.
  • Figure 2: Illustration of MIMO applied to a synthetic regression problem. (left) Example of MIMO learning $M=3$ diverse predictors. As $M$ increases, predicting with MIMO comes with a higher bias but a smaller variance (two middle panels respectively). Despite the slight increase in bias, the decrease in variance translates into an improved generalization performance (right).
  • Figure 3: Accuracy landscape and function space landscape comparison of individual subnetworks for MIMO (top row) and the naive multiheaded architecture (bottom row). (left): The test accuracy in the weight space section containing $M=3$ trained subnetworks and the origin. For the MIMO architecture, the individual subnetworks converge to three distinct low-loss basins, while naive multihead leads to the same mode. (middle-left to right): The blue, red and green panels show the disagreement between the three trained subnetworks for the same section of the weight space. For the MIMO architecture, the subnetworks often disagree, while for the naive multihead architecture they are all essentially equivalent.
  • Figure 4: Analyzing the subnetworks on the CIFAR10 dataset. (left): Histogram of the conditional variances of the pre-activations w.r.t. each input ($M=2$, ResNet28-10). (middle-left): Scatter plot of the conditional variances of the pre-activations w.r.t. each input. Almost all the pre-activations only have variance with respect to one of the inputs: the input of the subnetwork they are part of ($M=3$, ResNet28-10). (middle-right): Training trajectories of the subnetworks. The subnetworks converge to different local optima ($M=3$, SmallCNN). (right): Diversity of the members ($\mathcal{D}_D$) in different efficient ensemble models (ResNet28-10).
  • Figure 5: The performance of the subnetworks and the ensemble of the subnetworks as the number of subnetworks ($M$) varies. $M=1$ is equivalent to a standard neural network (ResNet28-10).
  • ...and 4 more figures
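Figures 3 and 4 characterize member diversity through disagreement between subnetworks. A minimal sketch of one common disagreement measure follows, assuming $\mathcal{D}_D$ is the fraction of test inputs on which two members' top-1 predictions differ, averaged over all member pairs; the exact definition and normalization used in the paper may differ, and `pairwise_disagreement` is a hypothetical helper.

```python
import itertools
import numpy as np

def pairwise_disagreement(probs):
    """Mean pairwise top-1 disagreement between ensemble members.

    probs: (M, N, K) predictive distributions of M members on N inputs.
    Assumed definition: fraction of inputs where two members' argmax
    predictions differ, averaged over all pairs (requires M >= 2).
    """
    preds = probs.argmax(axis=-1)  # (M, N) top-1 labels
    M = preds.shape[0]
    rates = [np.mean(preds[i] != preds[j])
             for i, j in itertools.combinations(range(M), 2)]
    return float(np.mean(rates))

probs = np.array([
    [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]],  # member 1 predicts 0, 1, 0
    [[0.7, 0.3], [0.4, 0.6], [0.3, 0.7]],  # member 2 predicts 0, 1, 1
    [[0.1, 0.9], [0.6, 0.4], [0.8, 0.2]],  # member 3 predicts 1, 0, 0
])
print(pairwise_disagreement(probs))  # about 0.667
```

Under this measure, a naive multi-head network whose heads collapse to the same mode (Figure 3, bottom row) would score near zero, while independently trained members score higher.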