Towards Understanding Regularization in Batch Normalization

Ping Luo, Xinjiang Wang, Wenqi Shao, Zhanglin Peng

TL;DR

This work provides a theoretical framework for understanding Batch Normalization (BN). Viewed as an implicit regularizer, BN is decomposed into population normalization (PN) and gamma decay, an explicit regularization derived from priors on batch statistics. Using a single-layer perceptron as a building block in a teacher-student setting, the authors derive ODEs that describe the learning dynamics, show that BN permits larger maximum and effective learning rates, and analyze generalization with a statistical-mechanics approach. They also validate the theory with CNN experiments, showing that the regularization traits of BN match those of PN plus gamma decay under appropriate conditions, and that regularization can be maintained or enhanced with dropout at large batch sizes. The results unify optimization and generalization insights for BN and point to extensions of the analysis to deeper networks and other normalizers.

Abstract

Batch Normalization (BN) improves both convergence and generalization in training neural networks. This work explains these phenomena theoretically. We analyze BN using a basic block of neural networks, consisting of a kernel layer, a BN layer, and a nonlinear activation function. This basic network helps us understand the impact of BN in three respects. First, when BN is viewed as an implicit regularizer, it can be decomposed into population normalization (PN) and gamma decay, an explicit regularization. Second, the learning dynamics of BN and of this regularization show that training converges with a large maximum and effective learning rate. Third, the generalization of BN is explored using statistical mechanics. Experiments demonstrate that BN in convolutional neural networks shares the same regularization traits as the above analyses.
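To make the decomposition concrete, the following is a minimal NumPy sketch of the basic block (kernel layer, normalization, ReLU) under two treatments: BN with batch statistics, and PN with fixed population statistics plus a gamma-decay penalty $\zeta\gamma^2$. It is an illustration under stated assumptions, not the authors' implementation: the squared-error loss, the random teacher vector `w_teacher`, the sample sizes, and the value $\zeta=\frac{1}{4M}$ (borrowed from the Figure 1 caption) are all choices made for this example.

```python
import numpy as np

def batch_norm(h, gamma, beta, eps=1e-5):
    """Normalize pre-activations h with the statistics of the current batch."""
    return gamma * (h - h.mean()) / (h.std() + eps) + beta

def population_norm(h, gamma, beta, mu_pop, sigma_pop, eps=1e-5):
    """Normalize h with fixed population statistics instead of batch statistics."""
    return gamma * (h - mu_pop) / (sigma_pop + eps) + beta

def pn_gamma_decay_loss(y_hat, y, gamma, zeta):
    """Squared-error loss (an assumed choice) plus the penalty zeta * gamma**2."""
    return np.mean((y_hat - y) ** 2) + zeta * gamma ** 2

rng = np.random.default_rng(0)
N, M = 64, 32                      # input dimension and batch size (assumed)
gamma, beta = 1.5, 0.0             # scale and shift parameters (assumed)
zeta = 1.0 / (4 * M)               # illustrative strength, see Figure 1 caption

w = rng.normal(size=N)             # student kernel
w_teacher = rng.normal(size=N)     # hypothetical teacher kernel

# Inputs drawn as x ~ N(0, I/N); a large sample stands in for the population.
x_pop = rng.normal(scale=1.0 / np.sqrt(N), size=(20_000, N))
h_pop = x_pop @ w
mu_pop, sigma_pop = h_pop.mean(), h_pop.std()

x_batch = x_pop[:M]
h_batch = x_batch @ w
y = np.maximum(x_batch @ w_teacher, 0.0)   # ReLU teacher outputs as targets

out_bn = np.maximum(batch_norm(h_batch, gamma, beta), 0.0)
out_pn = np.maximum(population_norm(h_batch, gamma, beta, mu_pop, sigma_pop), 0.0)

loss_bn = np.mean((out_bn - y) ** 2)
loss_pn_reg = pn_gamma_decay_loss(out_pn, y, gamma, zeta)
print("BN loss:", loss_bn)
print("PN + gamma decay loss:", loss_pn_reg)
```

The two losses are not expected to coincide on a single batch; the sketch only shows where the batch statistics are replaced by population statistics and where the explicit penalty on $\gamma$ enters.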

Paper Structure

This paper contains 28 sections, 5 theorems, 49 equations, 4 figures, 2 tables.

Key Result

Proposition 1

Let $(Q_0,R_0,L_0)$ denote a fixed point with parameters $Q$, $R$ and $L$ of Eqn.(eq:QRL). Assume the learning rate $\eta$ is sufficiently small when training converges and $x\sim\mathcal{N}(0,\frac{1}{N}\mathbf{I})$. If activation function $g$ is $\mathrm{ReLU}$, then we have $Q_0=\frac{1}{2\zeta+1

Figures (4)

  • Figure 1: (a) shows generalization error vs. effective load $\alpha$ using a linear student (identity units). 'WN+gamma decay' has two curves, $\zeta=\frac{1}{2M}$ and $\zeta=0.25$. BN is trained with $M=32$. (b) shows generalization error vs. effective load $\alpha$ using a ReLU student. 'WN+gamma decay' has $\zeta=\frac{1}{4M}$ and is compared to BN with batch size $M=32$. The theoretical curve for vanilla SGD is also shown in blue. The red line is the generalization error of vanilla SGD with no noise in the teacher and thus serves as a lower bound.
  • Figure 2: (a) & (b) compare the loss (both training and evaluation) and validation accuracy between BN and PN on CIFAR10 using a ResNet18 network; (c) & (d) compare the training and validation loss curve with WN + mean-only BN and WN + variance-only BN; (e) & (f) validate the regularization effect of BN on both $\gamma^2$ and the validation loss with different batch sizes; (g) & (h) show the loss and top-1 validation accuracy of ResNet18 with additional regularization (dropout) on large-batch training of BN and WN.
  • Figure 3: Results of downsampled ImageNet. (a) plots training and evaluation loss. (b) shows validation accuracy. The models are trained on 8 GPUs.
  • Figure 4: Study of parameter norm. Vanilla SGD is finetuned from a network pretrained by BN on CIFAR10. The first four figures show the magnitude of the kernel parameters in different layers in finetuning, compared to the effective norm of BN defined as $\gamma\frac{\|{ \mathbf{w} }\|}{\sigma_\mathcal{B}}$. The last two figures compare the training and validation losses in finetuning.
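
The effective norm $\gamma\frac{\|\mathbf{w}\|}{\sigma_\mathcal{B}}$ used in Figure 4 can be computed for a single neuron as in the sketch below. This is an assumption-laden illustration: it treats one kernel vector `w` applied to a random batch, whereas the paper's per-layer treatment for convolutional kernels is not reproduced here, and the function name `effective_norm` is hypothetical.

```python
import numpy as np

def effective_norm(w, gamma, x_batch):
    """Return gamma * ||w|| / sigma_B for one neuron.

    sigma_B is the standard deviation of the pre-activations w^T x
    over the batch (rows of x_batch).
    """
    h = x_batch @ w
    sigma_b = h.std()
    return gamma * np.linalg.norm(w) / sigma_b

rng = np.random.default_rng(0)
x_batch = rng.normal(size=(32, 16))   # batch of 32 inputs, 16 features (assumed)
w = rng.normal(size=16)

n1 = effective_norm(w, 0.8, x_batch)
n2 = effective_norm(10.0 * w, 0.8, x_batch)   # rescaled kernel
print(n1, n2)
```

Rescaling `w` by a positive constant rescales $\sigma_\mathcal{B}$ by the same factor, so the two printed values agree; this is why the effective norm, rather than the raw $\|\mathbf{w}\|$, is the quantity compared against the kernel magnitude of a network without BN.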

Theorems & Definitions (11)

  • proof
  • Proposition 1
  • proof
  • Proposition 2
  • proof
  • Proposition 3
  • proof
  • Proposition 4
  • proof
  • Proposition 5
  • ...and 1 more