
Evaluating Prediction-Time Batch Normalization for Robustness under Covariate Shift

Zachary Nado, Shreyas Padhy, D. Sculley, Alexander D'Amour, Balaji Lakshminarayanan, Jasper Snoek

TL;DR

This work tackles miscalibration and accuracy degradation caused by covariate shift by introducing prediction-time batch normalization, which recalculates BN statistics on small unlabeled prediction-time batches. The approach is simple and computationally efficient, yielding strong calibration and accuracy improvements (notably an mCE of 60.28% on ImageNet-C) and complementing deep ensembles. Through extensive ablations and cross-domain experiments (CIFAR-10-C, ImageNet-C, Criteo), the paper elucidates the mechanism as activation distribution alignment and avoidance of regions of high uncertainty, while also identifying limitations with pretraining and some natural shifts. The findings offer a practical, low-overhead tool for robustness in real-world deployments, with clear directions for future work on pretraining interactions and broader shift types.
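As a rough illustration of the difference between train BN and prediction-time BN, the NumPy sketch below normalizes a covariate-shifted batch in two ways: with statistics stored during training, and with statistics recomputed on the prediction batch itself. This is not the authors' implementation. The function names, the dense (batch, channels) activation layout, and the synthetic shift are assumptions made for illustration. For convolutional layers, the statistics would instead be reduced over the batch and spatial axes.

```python
import numpy as np


def batch_norm(x, gamma, beta, mean, var, eps=1e-5):
    """Normalize x of shape (batch, channels) with the given per-channel statistics."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


def train_bn(x, gamma, beta, running_mean, running_var):
    # Standard inference: reuse the statistics accumulated during training.
    return batch_norm(x, gamma, beta, running_mean, running_var)


def prediction_time_bn(x, gamma, beta):
    # Prediction-time BN (hypothetical helper): recompute the mean and variance
    # on the unlabeled prediction batch itself, ignoring the stored statistics.
    return batch_norm(x, gamma, beta, x.mean(axis=0), x.var(axis=0))


rng = np.random.default_rng(0)
channels = 4
gamma, beta = np.ones(channels), np.zeros(channels)

# Synthetic training activations ~ N(0, 1); their statistics act as the running estimates.
train_acts = rng.normal(0.0, 1.0, size=(10_000, channels))
running_mean, running_var = train_acts.mean(axis=0), train_acts.var(axis=0)

# Synthetic covariate-shifted prediction batch: activations moved and rescaled.
shifted = rng.normal(3.0, 2.0, size=(500, channels))

out_train = train_bn(shifted, gamma, beta, running_mean, running_var)
out_pred = prediction_time_bn(shifted, gamma, beta)
print("train BN mean/std:     ", out_train.mean(), out_train.std())
print("prediction BN mean/std:", out_pred.mean(), out_pred.std())
```

With the stored statistics, the shifted activations stay far from zero mean and unit variance; recomputing the statistics on the batch re-centers and rescales every channel. Because batch statistics are only estimates, their quality depends on the prediction batch size.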

Abstract

Covariate shift has been shown to sharply degrade both predictive accuracy and the calibration of uncertainty estimates for deep learning models. This is worrying, because covariate shift is prevalent in a wide range of real-world deployment settings. However, in this paper, we note that there is frequently the potential to access small unlabeled batches of the shifted data just before prediction time. This observation enables a simple but surprisingly effective method, which we call prediction-time batch normalization, that significantly improves model accuracy and calibration under covariate shift. Using this one-line code change, we achieve state-of-the-art results on recent covariate shift benchmarks and an mCE of 60.28% on the challenging ImageNet-C dataset; to our knowledge, this is the best result for any model that does not incorporate additional data augmentation or modification of the training pipeline. We show that prediction-time batch normalization provides complementary benefits to existing state-of-the-art approaches for improving robustness (e.g. deep ensembles), and that combining the two further improves performance. Our findings are supported by detailed measurements of the effect of this strategy on model behavior across rigorous ablations on various dataset modalities. However, the method has mixed results when used alongside pre-training and does not seem to perform as well under more natural types of dataset shift, so it is worthy of additional study. To improve reproducibility, we include links to the data in our figures, including a Python notebook that can be run to easily modify our analysis, at https://colab.research.google.com/drive/11N0wDZnMQQuLrRwRoumDCrhSaIhkqjof.
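The abstract reports ImageNet-C results as mCE. Assuming the standard definition from the ImageNet-C benchmark of Hendrycks and Dietterich (this section does not spell it out), each corruption's top-1 errors are summed over the five severity levels and divided by the corresponding sum for an AlexNet baseline, and mCE averages these ratios over corruption types. The helper names and all error values below are hypothetical.

```python
def corruption_error(model_errs, baseline_errs):
    """CE for one corruption: summed error over severities, normalized by the baseline's."""
    return sum(model_errs) / sum(baseline_errs)


def mean_corruption_error(model, baseline):
    """mCE in percent: the average CE over all corruption types.

    `model` and `baseline` map a corruption name to a list of top-1 error
    rates, one per severity level (1 to 5 in ImageNet-C).
    """
    ces = [corruption_error(model[name], baseline[name]) for name in model]
    return 100.0 * sum(ces) / len(ces)


# Hypothetical error rates for two corruption types (ImageNet-C itself uses 15).
model_errors = {
    "gaussian_noise": [0.25, 0.375, 0.5, 0.625, 0.75],
    "motion_blur": [0.125, 0.1875, 0.25, 0.3125, 0.375],
}
alexnet_errors = {
    "gaussian_noise": [0.5, 0.5, 0.5, 0.5, 0.5],
    "motion_blur": [0.375, 0.4375, 0.5, 0.5625, 0.625],
}
print(mean_corruption_error(model_errors, alexnet_errors))  # → 75.0
```

Under this definition, a value below 100% means the model's corruption errors are lower than the baseline's on average, so lower mCE is better.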

Paper Structure

This paper contains 26 sections, 2 equations, 27 figures, 3 tables.

Figures (27)

  • Figure 1: Empirical distributions of the outputs of selected normalization layers in ResNet-20 on CIFAR10 and CIFAR10-C. Activations are averaged over spatial dimensions, resulting in one distribution per output channel. The activations are recorded immediately after the batch normalization layer, before the non-linearity of each layer. The blue and red curves are aggregated across all shifted examples, while the yellow curve is aggregated across all training examples. Prediction-time BN is clearly much more effective at aligning the shifted activations with the support and shape of the training distribution. These layers were picked as representative of the activations of all normalization layers in the model; companion figures (fig:cifar10_bn_activation_dists_ema_all, fig:cifar10_bn_activation_dists_test_batch_all) show all layers.
  • Figure 2: Brier scores of predictions increase as the activations from the training and test sets occupy increasingly distinct regions. Here, we summarize how the distributions of penultimate hidden-layer activations $\mathbf{h} = g(\mathbf{x})$ on shifted test sets compare to their distributions on the training set, under a number of different normalization schemes. Each point represents a type of shift, and its color indicates the intensity of the shift applied. On the horizontal axis, we plot a measure of the discrepancy between the training and test distributions of activations, $p(\mathbf{h})$ and $q(\mathbf{h})$, respectively, approximating $KL(q(\mathbf{h}) \,\|\, p(\mathbf{h})) \approx T^{-1}\sum_{i=1}^{T} \ln\frac{\hat{q}(\mathbf{h}_i)}{\hat{p}(\mathbf{h}_i)}$, where the summation is taken over the $T$ test instances, and $\hat{p}$ and $\hat{q}$ are multivariate normal densities whose means and variances match $p$ and $q$, respectively. We use KL divergence because it is particularly sensitive to cases where test activations lie outside the effective support of the training activations. On the vertical axes are the Brier scores of the shifted examples, averaged within each split. In addition to the overall trend of increasing discrepancy leading to decreasing performance, we also see that higher shift intensities tend to have higher support mismatch; intuitively, more shift in the inputs should lead to more shift in the activation values. Only when prediction-time BN is used to compute $g(\mathbf{x})$ are the activations more closely aligned and the Brier scores more consistently lower. A companion figure (fig:cifar10_likelihood_ratio_scatter_all) shows similar trends in accuracy.
  • Figure 3: Calibration on CIFAR10-C, ImageNet-C, and Criteo across increasing levels of dataset shift. The box plots show the median, quartiles, minimum, and maximum performance per method.
  • Figure 4: CIFAR-10-C Brier score at shift level 5 for different prediction batch sizes. Only relatively small batch sizes are required to effectively correct for covariate shift, with only marginal improvements beyond 100 examples. Companion figures show how other metrics vary with prediction batch size (fig:cifar_test_bs_all) and a similar trend on ImageNet-C (fig:imagenet_test_bs_all).
  • Figure 5: Calibration on CIFAR-10-C (top) and ImageNet-C (bottom), both with a prediction batch size of 500. Having access to just a single batch from each split (frozen Prediction BN) is sufficient to obtain substantial performance improvements. Also, while prediction-time BN is sensitive to multiple simultaneous types of covariate shift, it still outperforms train BN. Companion figures (fig:cifar_bn_stat_method_all, fig:imagenet_bn_stat_method_all) show the accuracy and Brier score performance.
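Figure 2 plots a Gaussian-fit KL discrepancy against the Brier score. The sketch below computes both quantities on synthetic activations. Averaging $\ln(\hat{q}/\hat{p})$ over test samples, which are drawn from $q$, estimates the divergence of the test distribution from the training one, which grows when test activations fall where training activations are rare. Several details are assumptions not stated in the caption: diagonal covariances for $\hat{p}$ and $\hat{q}$, a small variance floor `eps`, and the Brier convention of summing squared errors over classes before averaging over examples. All names are hypothetical.

```python
import numpy as np


def diag_gaussian_logpdf(h, mean, var):
    """Log density of a diagonal-covariance normal, evaluated row-wise on h of shape (T, D)."""
    return -0.5 * np.sum(np.log(2.0 * np.pi * var) + (h - mean) ** 2 / var, axis=1)


def gaussian_kl_estimate(train_h, test_h, eps=1e-6):
    """Monte Carlo estimate of KL(q || p) using the T test activations as samples.

    p_hat and q_hat are normal fits (diagonal covariance, an assumption) to the
    training and test activations, respectively.
    """
    p_mean, p_var = train_h.mean(axis=0), train_h.var(axis=0) + eps
    q_mean, q_var = test_h.mean(axis=0), test_h.var(axis=0) + eps
    log_ratio = (diag_gaussian_logpdf(test_h, q_mean, q_var)
                 - diag_gaussian_logpdf(test_h, p_mean, p_var))
    return log_ratio.mean()


def brier_score(probs, labels):
    """Multi-class Brier score: squared error between predicted probabilities and
    one-hot labels, summed over classes and averaged over examples."""
    one_hot = np.eye(probs.shape[1])[labels]
    return np.mean(np.sum((probs - one_hot) ** 2, axis=1))


rng = np.random.default_rng(0)
train_h = rng.normal(0.0, 1.0, size=(5000, 8))
aligned_h = rng.normal(0.0, 1.0, size=(1000, 8))  # test activations matching training
shifted_h = rng.normal(2.0, 1.5, size=(1000, 8))  # test activations drifting away
print(gaussian_kl_estimate(train_h, aligned_h))  # close to 0
print(gaussian_kl_estimate(train_h, shifted_h))  # far from 0 (analytic value is about 17.8)

probs = np.array([[0.9, 0.1], [0.2, 0.8]])
labels = np.array([0, 1])
print(brier_score(probs, labels))  # ≈ 0.05
```

In the paper's setting, `train_h` and `test_h` would be penultimate-layer activations computed under each normalization scheme, giving one (discrepancy, Brier score) point per shift type and intensity.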