
BN-SCAFFOLD: controlling the drift of Batch Normalization statistics in Federated Learning

Gonzalo Iñaki Quintana, Laurence Vancamberg, Vincent Jugnon, Mathilde Mougeot, Agnès Desolneux

TL;DR

Federated Learning (FL) with Batch Normalization (BN) faces performance degradation under data heterogeneity due to BN statistics drift across clients. The authors propose BN-SCAFFOLD, a variance-reduction method that extends SCAFFOLD to BN statistics, and develop a unified convergence framework for BN-DNNs in FL. They prove convergence guarantees and show experimentally on MNIST and CIFAR-10 that BN-SCAFFOLD achieves FedTAN-like performance without FedTAN’s depth-dependent communication overhead, outperforming FedAvg, SCAFFOLD, and other BN-heterogeneity approaches. The approach also extends to other normalization strategies and offers a practical, communication-efficient solution for BN-enabled FL in cross-silo settings, with potential privacy and scalability benefits.
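To make "BN statistics drift" concrete, the following toy sketch compares the statistics a BN layer would compute locally on each client with those of the pooled data. It assumes two hypothetical, equally sized clients whose activations come from very different ranges. Averaging the per-client means recovers the global mean in this equal-size case. Averaging the per-client variances does not recover the global variance, because it misses the between-client variance term. This only illustrates the heterogeneity problem; it is not the BN-SCAFFOLD correction.

```python
from statistics import fmean, pvariance

# Hypothetical activations entering one BN channel on two non-IID clients.
clients = {
    "client_a": [0.0, 1.0, 2.0, 3.0],
    "client_b": [10.0, 11.0, 12.0, 13.0],
}

# BN normalizes a training batch with its mean and biased (population) variance.
local_means = {k: fmean(v) for k, v in clients.items()}
local_vars = {k: pvariance(v) for k, v in clients.items()}

# Statistics of the pooled data, i.e. what a centralized model would see.
pooled = [x for v in clients.values() for x in v]
global_mean = fmean(pooled)
global_var = pvariance(pooled)

# Naive aggregation: average the per-client statistics (equal client sizes).
avg_mean = fmean(list(local_means.values()))
avg_var = fmean(list(local_vars.values()))

# Law of total variance: pooled variance = mean of local variances
# + variance of the local means (the term naive averaging drops).
between_var = pvariance(list(local_means.values()))

print(f"global mean {global_mean}, averaged local means {avg_mean}")
print(f"global var {global_var}, averaged local vars {avg_var}, "
      f"between-client var {between_var}")
```

The gap grows with the spread of the client means, which is one way to see why BN becomes problematic as heterogeneity increases.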

Abstract

Federated Learning (FL) is gaining traction as a learning paradigm for training Machine Learning (ML) models in a decentralized way. Batch Normalization (BN) is ubiquitous in Deep Neural Networks (DNNs), as it improves convergence and generalization. However, BN has been reported to hinder the performance of DNNs in heterogeneous FL. Recently, the FedTAN algorithm has been proposed to mitigate the effect of heterogeneity on BN by aggregating BN statistics and gradients from all the clients. However, it has a high communication cost, which increases linearly with the depth of the DNN. SCAFFOLD is a variance reduction algorithm that estimates and corrects the client drift in a communication-efficient manner. Despite its promising results in heterogeneous FL settings, it has been reported to underperform for models with BN. In this work, we seek to revive SCAFFOLD, and more generally variance reduction, as an efficient way of training DNNs with BN in heterogeneous FL. We introduce a unified theoretical framework for analyzing the convergence of variance reduction algorithms in the BN-DNN setting, inspired by the work of Wang et al. (2023), and show that SCAFFOLD is unable to remove the bias introduced by BN. We thus propose the BN-SCAFFOLD algorithm, which extends the client drift correction of SCAFFOLD to BN statistics. We prove convergence using the aforementioned framework and validate the theoretical results with experiments on MNIST and CIFAR-10. BN-SCAFFOLD matches the performance of FedTAN without its high communication cost, outperforming Federated Averaging (FedAvg), SCAFFOLD, and other FL algorithms designed to mitigate BN heterogeneity.
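For reference, the client drift correction of SCAFFOLD, which BN-SCAFFOLD extends to BN statistics, can be sketched on a toy problem without BN. This is a minimal sketch under stated assumptions, not the authors' BN-SCAFFOLD implementation: two hypothetical clients with 1-D quadratic losses $f_i(x) = \tfrac{1}{2} a_i (x - b_i)^2$, full participation, exact gradients, a global step size of 1, and the "option II" control-variate update of the original SCAFFOLD algorithm (Karimireddy et al., 2020). The values of `A`, `B`, the learning rate and the number of local steps are illustrative choices.

```python
"""Toy FedAvg vs SCAFFOLD on two heterogeneous 1-D quadratic clients
f_i(x) = 0.5 * a_i * (x - b_i)**2 (hypothetical data, no BN)."""

A = [1.0, 4.0]  # client curvatures (hypothetical)
B = [0.0, 1.0]  # client minimizers (hypothetical)
# Minimizer of the average objective (1/N) * sum_i f_i.
X_STAR = sum(a * b for a, b in zip(A, B)) / sum(A)


def grad(i: int, y: float) -> float:
    """Exact gradient of client i's loss at y."""
    return A[i] * (y - B[i])


def fedavg(rounds: int = 100, local_steps: int = 10, lr: float = 0.01) -> float:
    x = 0.0
    for _ in range(rounds):
        ys = []
        for i in range(len(A)):
            y = x
            for _ in range(local_steps):
                y -= lr * grad(i, y)  # plain local gradient step
            ys.append(y)
        x = sum(ys) / len(ys)  # server averages the local models
    return x


def scaffold(rounds: int = 100, local_steps: int = 10, lr: float = 0.01) -> float:
    n = len(A)
    x = 0.0
    c_i = [0.0] * n  # client control variates
    c = 0.0          # server control variate (mean of the c_i)
    for _ in range(rounds):
        ys, new_c_i = [], []
        for i in range(n):
            y = x
            for _ in range(local_steps):
                # Drift-corrected step: local gradient - c_i + c.
                y -= lr * (grad(i, y) - c_i[i] + c)
            ys.append(y)
            # Option II: c_i+ = c_i - c + (x - y_i) / (K * lr).
            new_c_i.append(c_i[i] - c + (x - y) / (local_steps * lr))
        x = sum(ys) / n          # global step size 1
        c_i = new_c_i
        c = sum(c_i) / n         # full participation: c is the mean of the c_i
    return x


print(f"x* = {X_STAR:.4f}, FedAvg = {fedavg():.4f}, SCAFFOLD = {scaffold():.4f}")
```

With these settings, the local steps pull each FedAvg client toward its own minimizer and the averaged model settles at a biased point (about 0.778), whereas the SCAFFOLD control variates cancel the drift and the iterate converges to $x^\star = 0.8$. According to the abstract, this correction alone does not remove the bias that BN introduces, which is what BN-SCAFFOLD addresses by applying the same idea to BN statistics.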

Paper Structure (40 sections, 29 theorems, 176 equations, 3 figures, 11 tables, 1 algorithm)

Key Result

Theorem 3.1

The convergence rate of the family of algorithms defined in Definition def:variance_reduc_algos, for a general non-convex objective function $F$ and under the assumptions discussed in Section sec:assumptions_conv_analysis, is given by … where $\mathcal{T} = \mathcal{T}_1 + \mathcal{T}_2 + \mathcal{T}_3 + \mathcal{T}_4 + \mathcal{T}_5 - \mathcal{T}_6$, with … and with $\delta_{E, \gamma, L} \coloneq \ldots$

Figures (3)

  • Figure 1: Comparison of FedAvg, SCAFFOLD, and BN-SCAFFOLD in MNIST with $N$=2 clients.
  • Figure 2: Classification performance with 95% CI for different degrees of heterogeneity, using the original X-ray images.
  • Figure 3: Classification performance with 95% CI for different degrees of heterogeneity, using synthetic images.

Theorems & Definitions (57)

  • Definition 1
  • Theorem 3.1
  • Lemma E.1: Jensen's inequality
  • Corollary E.0.1
  • proof
  • Lemma E.2: Relaxed triangle inequality
  • proof
  • Lemma E.3
  • proof
  • Lemma E.4: Separating mean and variance
  • ...and 47 more