
Making Batch Normalization Great in Federated Deep Learning

Jike Zhong, Hong-You Chen, Wei-Lun Chao

TL;DR

This work reexamines the common belief that BN harms federated learning, showing empirically that BN often beats GN except in extreme non-IID or high-frequency communication regimes. It identifies BN-specific issues, namely gradient bias from mismatched mini-batch statistics and misalignment between training and testing statistics, and introduces FixBN, a simple two-stage approach that preserves the benefits of BN while mitigating these issues at no extra training or communication cost. FixBN uses standard BN during an initial exploration phase and then switches to fixed global BN statistics in a calibration phase; this aligns training and testing normalization and lets FedAvg approximate centralized gradients even with frequent communication. The approach is validated on CIFAR-10, Tiny-ImageNet, ImageNet, and Cityscapes and, combined with maintained SGD momentum, yields robust improvements over BN and GN across diverse FL settings. The results offer a practical, scalable way to use BN in FL and motivate further theoretical analysis of BN dynamics under data heterogeneity and distributed optimization.
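To make the two phases concrete, here is a minimal NumPy sketch of a FixBN-style normalization layer; it is an illustration, not the authors' implementation. Assumptions: the class name FixBNSketch, its freeze method, the moving-average momentum of 0.1, and the rule that the fixed global statistics are the plain average of the clients' running statistics at the switching round are all hypothetical, and the toy data are random.

```python
import numpy as np


class FixBNSketch:
    """Hypothetical BN layer (features on the last axis) with two phases.

    Exploration: normalize with mini-batch statistics and update running
    statistics by an exponential moving average, as standard BN does.
    Calibration: normalize with fixed statistics in training and testing alike.
    """

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(num_features)   # learnable scale (not trained here)
        self.beta = np.zeros(num_features)   # learnable shift (not trained here)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.momentum = momentum             # assumed EMA rate
        self.eps = eps
        self.frozen = False  # False = exploration phase, True = calibration phase

    def freeze(self, global_mean, global_var):
        """Enter the calibration phase with server-provided global statistics."""
        self.running_mean = np.array(global_mean, dtype=float)
        self.running_var = np.array(global_var, dtype=float)
        self.frozen = True

    def forward(self, x, training=True):
        if training and not self.frozen:
            # Exploration: mini-batch statistics, plus an EMA update.
            mean, var = x.mean(axis=0), x.var(axis=0)
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # Calibration or evaluation: the same fixed statistics are used,
            # so training-time and test-time normalization agree.
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


rng = np.random.default_rng(0)
clients = [FixBNSketch(3), FixBNSketch(3)]

# Exploration phase: each client sees data with a different (non-IID) mean.
for shift, layer in zip([-1.0, 1.0], clients):
    for _ in range(50):
        layer.forward(rng.normal(loc=shift, size=(32, 3)))

# Server side (assumed rule): FedAvg-style averaging of running statistics.
global_mean = np.mean([c.running_mean for c in clients], axis=0)
global_var = np.mean([c.running_var for c in clients], axis=0)
for layer in clients:
    layer.freeze(global_mean, global_var)

# Calibration phase: a training-mode pass now matches an evaluation-mode pass.
x = rng.normal(loc=1.0, size=(32, 3))
train_out = clients[1].forward(x, training=True)
test_out = clients[1].forward(x, training=False)
```

In this sketch, once the statistics are frozen, a training-mode forward pass no longer changes them and produces the same output as evaluation mode, which is the training-testing alignment described above.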

Abstract

Batch Normalization (BN) is widely used in centralized deep learning to improve convergence and generalization. However, in federated learning (FL) with decentralized data, prior work has observed that training with BN can hinder performance and has suggested replacing it with Group Normalization (GN). In this paper, we revisit this substitution by expanding the empirical study conducted in prior work. Surprisingly, we find that BN outperforms GN in many FL settings; the exceptions are high-frequency communication and extreme non-IID regimes. We reinvestigate the factors believed to cause this problem, including the mismatch of BN statistics across clients and the deviation of gradients during local training. We empirically identify a simple practice that can reduce the impact of these factors while maintaining the strength of BN. Our approach, which we name FixBN, is easy to implement, incurs no additional training or communication cost, and performs favorably across a wide range of FL settings. We hope that our study can serve as a valuable reference for future practical usage and theoretical analysis in FL.
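The mismatch of BN statistics across clients can be seen in a toy setting. The NumPy sketch below uses entirely hypothetical data (five clients, each holding a single class that shifts a one-dimensional feature) and compares each client's mini-batch mean, which standard BN normalizes with during local training, against the mean pooled over all clients.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical extreme non-IID split: client c holds only class c, and
# class c shifts a one-dimensional feature to mean 2 * c.
num_clients, batch_size = 5, 20
client_data = [rng.normal(loc=2.0 * c, scale=1.0, size=1000) for c in range(num_clients)]

# Statistic a global model would ideally use (pooled over all clients).
global_mean = np.concatenate(client_data).mean()

# Statistic each client actually normalizes with during local BN training.
batch_means = [
    data[rng.choice(data.size, size=batch_size, replace=False)].mean()
    for data in client_data
]

for c, bm in enumerate(batch_means):
    print(f"client {c}: mini-batch mean {bm:+.2f}, pooled mean {global_mean:+.2f}")
```

Here client 0 normalizes with a mean near 0 and client 4 with one near 8, while the pooled mean is near 4, so local training and global testing see differently normalized features.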

Paper Structure

This paper contains 47 sections, 5 equations, 14 figures, 9 tables, 1 algorithm.

Figures (14)

  • Figure 1: Our approach, FixBN, notably bridges the gap between using BN in FL and in centralized learning. X-axis: communication rounds in FL, with communication after every local SGD step; y-axis: test accuracy on CIFAR-10 (Krizhevsky, 2009); $\star$: FixBN further combined with our SGD momentum in FL. See the analysis section and the SGD momentum section for details.
  • Figure 2: Is GN consistently better than BN in FL? No. We compare their test accuracy in various FL settings on CIFAR-10 and Tiny-ImageNet, including different non-IID partitions and numbers of local steps $E$. We consider (a) a fixed budget for the total number of SGD steps (e.g., for CIFAR-10, $20\text{ local steps } \times 20\text{ batch size }\times 5\text{ clients }\times 3200\text{ rounds} = 128\text{ epochs of CIFAR-10 training data}$) or (b) a fixed number of total communication rounds ($128$ rounds). Green cells: BN outperforms GN. Purple cells: GN outperforms BN.
  • Figure 3: Training dynamics of FedAvg with BN. (a) Changes of global accumulated statistics ($\|\bar{\boldsymbol{S}}^{(t+1)}-\bar{\boldsymbol{S}}^{(t)}\|_1$) and deviations of local mini-batch statistics from global accumulated statistics ($\|{\boldsymbol{S}_{m, \mathcal{B}}}^{(t+1)}-\bar{\boldsymbol{S}}^{(t)}\|_1$). (b) Variances (running over $t-200$ to $t$) of local mini-batch statistics $\boldsymbol{S}_{m, \mathcal{B}}^{(t)}$. (c) Averaged local training losses over clients. (d) Final-round accuracy when freezing BN statistics at different intermediate rounds in a non-IID CIFAR-10 setting (shards, $E$ = 100) using ResNet20. See the main text for details.
  • Figure 4: Non-IID partitions with $E=100$ steps.
  • Figure 5: Maintained momentum results. We apply maintained global ($^\dagger$) and local ($^\star$) momentum to FedAvg with different normalizers. The setting is (Shards, fixed 128 epochs) with different numbers of local steps $E$. We show results on CIFAR-10; results on Tiny-ImageNet follow similar trends (see the corresponding figure in the appendix).
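The fixed-budget arithmetic in Figure 2's caption can be checked directly. The helper names below are hypothetical, and the sketch assumes that all 5 clients participate in every round with batch size 20 and that the CIFAR-10 training split has 50,000 images. Under a fixed budget, more local steps per round means proportionally fewer communication rounds.

```python
CIFAR10_TRAIN_IMAGES = 50_000  # size of the CIFAR-10 training split


def epochs_from_budget(local_steps, batch_size, num_clients, rounds,
                       dataset_size=CIFAR10_TRAIN_IMAGES):
    """Passes over the training set implied by a federated SGD budget."""
    samples_seen = local_steps * batch_size * num_clients * rounds
    return samples_seen / dataset_size


def rounds_for_budget(epochs, local_steps, batch_size, num_clients,
                      dataset_size=CIFAR10_TRAIN_IMAGES):
    """Communication rounds needed to reach a fixed epoch budget."""
    return epochs * dataset_size / (local_steps * batch_size * num_clients)


print(epochs_from_budget(20, 20, 5, 3200))  # 128.0
for E in (1, 20, 100):
    # E = 1 -> 64000.0 rounds, E = 20 -> 3200.0, E = 100 -> 640.0
    print(E, rounds_for_budget(128, E, 20, 5))
```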
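The quantities plotted in Figure 3(a) and 3(b) are L1 distances between statistic vectors and a variance over a sliding window of steps. A NumPy sketch follows, assuming each statistic vector S concatenates a model's BN means and variances into one array; the function names and the random traces are hypothetical, and the window is taken to include both step t-200 and step t.

```python
import numpy as np


def l1_change(stats_a, stats_b):
    """L1 distance between statistic vectors, e.g. ||S̄^(t+1) - S̄^(t)||_1."""
    return float(np.abs(np.asarray(stats_a) - np.asarray(stats_b)).sum())


def windowed_variance(stat_history, t, window=200):
    """Per-coordinate variance of the statistics recorded at steps t-window..t."""
    history = np.asarray(stat_history)
    start = max(0, t - window)
    return history[start:t + 1].var(axis=0)


rng = np.random.default_rng(0)
T, d = 1000, 4  # hypothetical number of steps and statistic dimension

# Hypothetical traces: noisy local mini-batch statistics S_{m,B}^(t) and a
# slowly drifting global accumulated statistic S̄^(t).
local_stats = rng.normal(size=(T, d))
global_stats = np.cumsum(0.001 * rng.normal(size=(T, d)), axis=0)

# Panel (a): change of global statistics and deviation of local from global.
global_change = [l1_change(global_stats[t + 1], global_stats[t]) for t in range(T - 1)]
local_deviation = [l1_change(local_stats[t + 1], global_stats[t]) for t in range(T - 1)]

# Panel (b): variance of local mini-batch statistics over the last window.
local_variance = windowed_variance(local_stats, t=T - 1)
```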
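Figure 5's caption does not spell out how momentum is maintained. As one plausible reading, and not necessarily the paper's scheme, the sketch below assumes the server averages clients' momentum buffers along with the weights and sends the average back, instead of each client restarting from a zero buffer every round. The toy quadratic objective and all function names are hypothetical. With one local step per round (E = 1), this assumed variant matches centralized heavy-ball SGD up to floating-point error, while resetting the buffer reduces each round to plain SGD.

```python
import numpy as np


def local_momentum_sgd(w, buf, grad_fn, steps, lr=0.1, beta=0.9):
    """Heavy-ball SGD: buf <- beta * buf + grad; w <- w - lr * buf."""
    for _ in range(steps):
        buf = beta * buf + grad_fn(w)
        w = w - lr * buf
    return w, buf


def fedavg_momentum(targets, rounds, local_steps, maintain_momentum):
    """Toy FedAvg where client m minimizes 0.5 * ||w - targets[m]||^2."""
    dim = targets.shape[1]
    w_global = np.zeros(dim)
    buf_global = np.zeros(dim)
    for _ in range(rounds):
        new_w, new_buf = [], []
        for target in targets:
            # Assumed "maintained" variant: start from the averaged buffer;
            # otherwise reset the buffer to zero at every round.
            buf0 = buf_global if maintain_momentum else np.zeros(dim)
            w, buf = local_momentum_sgd(w_global, buf0,
                                        lambda v, t=target: v - t, local_steps)
            new_w.append(w)
            new_buf.append(buf)
        # Server averages weights and momentum buffers.
        w_global = np.mean(new_w, axis=0)
        buf_global = np.mean(new_buf, axis=0)
    return w_global


targets = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
w_kept = fedavg_momentum(targets, rounds=30, local_steps=1, maintain_momentum=True)
w_reset = fedavg_momentum(targets, rounds=30, local_steps=1, maintain_momentum=False)

# Reference: centralized heavy-ball SGD on the average objective.
w_central, buf = np.zeros(2), np.zeros(2)
for _ in range(30):
    buf = 0.9 * buf + (w_central - targets.mean(axis=0))
    w_central = w_central - 0.1 * buf
```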