Admissibility of Stein Shrinkage for Batch Normalization in the Presence of Adversarial Attacks
Sofia Ivolgina, P. Thomas Fletcher, Baba C. Vemuri
TL;DR
This work introduces a Stein/James–Stein shrinkage approach to estimating batch normalization (BN) statistics and proves that the shrinkage estimators dominate the standard BN estimators under adversarial perturbations modeled as sub-Gaussian noise. The authors derive JS-based corrections for both the BN means and the BN variances (the latter via a Gamma-distribution framework) and prove dominance theorems in the Gaussian and Gamma settings. The shrinkage yields more stable BN statistics and smoother loss landscapes, and the experiments show substantial robustness improvements on CIFAR-10, Cityscapes, and PPMI data, including resilience to FGSM and PGD attacks. Taken together, the theoretical guarantees and extensive experiments suggest that the method can improve the robustness of deep networks without sacrificing accuracy, particularly in small-batch or adversarial settings.
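As a hedged illustration of the dominance claim (not the authors' proof or code), the Monte Carlo sketch that follows compares the squared error of the raw noisy observations with that of the positive-part James-Stein estimator, which shrinks toward the origin. It assumes the total per-coordinate noise variance is known and models the adversarial component as a bounded uniform perturbation, which is sub-Gaussian. The names `james_stein` and `compare_risk` and all parameter values are hypothetical, and a simulation like this only illustrates the effect; it does not prove dominance.

```python
import numpy as np


def james_stein(x, noise_var):
    """Positive-part James-Stein estimate of the mean vector behind x.

    x: 1-D array of p >= 3 independent noisy observations.
    noise_var: per-coordinate noise variance (assumed known here).
    """
    p = x.shape[0]
    shrink = 1.0 - (p - 2) * noise_var / np.sum(x ** 2)
    return max(shrink, 0.0) * x  # clip at 0 so the sign is never flipped


def compare_risk(p=64, sigma=1.0, attack=0.5, trials=2000, seed=0):
    """Average squared error of raw observations vs. James-Stein (hypothetical setup)."""
    rng = np.random.default_rng(seed)
    theta = rng.normal(0.0, 1.0, size=p)  # fixed "true" mean vector
    # Total variance: Gaussian noise plus a uniform perturbation on
    # [-attack, attack], whose variance is attack**2 / 3.
    noise_var = sigma ** 2 + attack ** 2 / 3.0
    se_raw, se_js = 0.0, 0.0
    for _ in range(trials):
        x = (theta
             + rng.normal(0.0, sigma, size=p)
             + rng.uniform(-attack, attack, size=p))
        se_raw += np.sum((x - theta) ** 2)
        se_js += np.sum((james_stein(x, noise_var) - theta) ** 2)
    return se_raw / trials, se_js / trials


mse_raw, mse_js = compare_risk()
print(f"unshrunk MSE: {mse_raw:.2f}, James-Stein MSE: {mse_js:.2f}")
```

With p = 64 coordinates, the unshrunk error should sit near 64 times the per-coordinate variance, and the James-Stein error should be noticeably smaller.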
Abstract
Batch normalization (BN) is a ubiquitous operation in deep neural networks, used primarily to stabilize and regularize training. BN centers and scales feature maps using sample means and variances, which are naturally suited to Stein's shrinkage estimation. Applying such shrinkage yields more accurate estimates of the batch mean and variance in the mean-squared-error sense. In this paper, we prove that the Stein shrinkage estimators of the mean and variance dominate the sample mean and sample variance estimators, respectively, in the presence of adversarial attacks modeled using sub-Gaussian distributions. Furthermore, by construction, James-Stein (JS) BN has a smaller local Lipschitz constant than vanilla BN, implying better regularity properties and potentially improved robustness. These results justify using Stein shrinkage to estimate the mean and variance parameters in BN for image classification and segmentation tasks, with and without adversarial attacks. We present state-of-the-art results using this Stein-corrected BN in a standard ResNet for image classification on CIFAR-10, in a 3D CNN on PPMI (neuroimaging) data, and in HRNet for image segmentation on Cityscapes, with and without adversarial attacks.
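A minimal NumPy sketch of a BN training-mode forward pass with James-Stein-shrunk channel means follows. It is an assumption-laden illustration, not the paper's implementation. It standardizes each channel mean by its estimated standard error, shrinks the resulting vector toward zero with the classic positive-part factor 1 - (C - 2)/||z||^2, and keeps the ordinary biased sample variance, so the paper's Gamma-based variance correction is not reproduced. The functions `js_shrink_channel_means` and `js_batch_norm` are hypothetical names.

```python
import numpy as np


def js_shrink_channel_means(x):
    """Shrink per-channel batch means of x (shape N, C, H, W; C >= 3) toward zero."""
    n, c, h, w = x.shape
    m = n * h * w                      # number of values per channel
    mu = x.mean(axis=(0, 2, 3))        # per-channel sample mean
    var = x.var(axis=(0, 2, 3))        # per-channel biased sample variance
    se2 = np.maximum(var / m, 1e-12)   # estimated variance of each channel mean
    # Standardize each mean so all coordinates have ~unit noise variance,
    # then apply the positive-part James-Stein factor.
    z_sq = max(np.sum(mu ** 2 / se2), 1e-12)
    factor = max(1.0 - (c - 2) / z_sq, 0.0)
    return factor * mu, var


def js_batch_norm(x, gamma, beta, eps=1e-5):
    """BN forward pass (training mode) using the shrunk means (hypothetical sketch)."""
    mu_js, var = js_shrink_channel_means(x)
    shape = (1, -1, 1, 1)              # broadcast per-channel stats over N, H, W
    x_hat = (x - mu_js.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    return gamma.reshape(shape) * x_hat + beta.reshape(shape)


rng = np.random.default_rng(0)
x = rng.normal(0.1, 1.0, size=(8, 16, 4, 4))
y = js_batch_norm(x, gamma=np.ones(16), beta=np.zeros(16))
print(y.shape)  # (8, 16, 4, 4)
```

The shrinkage factor lies in [0, 1], so each shrunk mean is never larger in magnitude than the sample mean. When the channel means are large relative to their standard errors, the factor approaches 1 and the layer reduces to ordinary BN. Running statistics for inference are omitted.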
