Understanding the Role of Layer Normalization in Label-Skewed Federated Learning

Guojun Zhang, Mahdi Beitollahi, Alex Bie, Xi Chen

TL;DR

We study layer normalization in federated learning under extreme label skew, revealing that feature normalization (FN) is the key mechanism by which LN improves convergence and robustness. By showing that, in scale-equivariant networks, LN/FN largely reduce to last-layer scaling, we explain how LN combats feature collapse and local overfitting on skewed clients. The work provides extensive empirical benchmarks across CNN/ResNet architectures and datasets (CIFAR-10/100, TinyImageNet, PACS), along with ablations demonstrating FN’s essential role and LN’s limited impact on expressive power. The results suggest practical benefits for FL systems facing severe label distribution heterogeneity and offer a theory-grounded direction for future normalization-based improvements and cross-domain validation.

Abstract

Layer normalization (LN) is a widely adopted deep learning technique, especially in the era of foundation models. Recently, LN has been shown to be surprisingly effective in federated learning (FL) with non-i.i.d. data. However, exactly why and how it works remains mysterious. In this work, we reveal the profound connection between layer normalization and the label shift problem in federated learning. To understand layer normalization better in FL, we identify the key contributing mechanism of normalization methods in FL, called feature normalization (FN), which applies normalization to the latent feature representation before the classifier head. Although LN and FN do not improve expressive power, they control feature collapse and local overfitting to heavily skewed datasets, and thus accelerate global training. Empirically, we show that normalization leads to drastic improvements on standard benchmarks under extreme label shift. Moreover, we conduct extensive ablation studies to understand the critical factors of layer normalization in FL. Our results verify that FN is an essential ingredient inside LN to significantly improve the convergence of FL while remaining robust to learning rate choices, especially under extreme label shift where each client has access to few classes. Our code is available at https://github.com/huawei-noah/Federated-Learning/tree/main/Layer_Normalization.
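To make the idea of feature normalization concrete, here is a minimal NumPy sketch. It assumes that FN means rescaling each latent feature vector to unit L2 norm before a linear classifier head, which matches the "scale normalizing" wording of Proposition 1; the paper's exact definition may differ (for example, it may also center the features). The names `feature_normalize`, `classify`, and `head_weight` are hypothetical and are not taken from the released code.

```python
import numpy as np

def feature_normalize(features, eps=1e-12):
    """Rescale each feature vector (row) to unit L2 norm (assumed form of FN)."""
    norms = np.linalg.norm(features, axis=-1, keepdims=True)
    return features / np.maximum(norms, eps)

def classify(features, head_weight, use_fn=True):
    """Apply a bias-free linear head to (optionally normalized) features."""
    if use_fn:
        features = feature_normalize(features)
    return features @ head_weight.T

rng = np.random.default_rng(0)
# Four feature vectors whose norms differ by orders of magnitude,
# mimicking features that blow up during skewed local training.
row_scales = np.array([[1.0], [10.0], [100.0], [1000.0]])
features = rng.normal(size=(4, 8)) * row_scales
head_weight = rng.normal(size=(3, 8))  # 3 classes, 8-dimensional features

logits_plain = classify(features, head_weight, use_fn=False)
logits_fn = classify(features, head_weight, use_fn=True)
unit_norms = np.linalg.norm(feature_normalize(features), axis=-1)
```

Without FN, the logits grow with the feature norm; with FN, each logit is bounded in magnitude by the norm of the corresponding row of `head_weight`, so a client cannot inflate its confidence on its local classes simply by scaling up the features.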

Paper Structure

This paper contains 42 sections, 13 theorems, 43 equations, 9 figures, and 14 tables.

Key Result

Proposition 1

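Under the scale-equivariance assumption, scale normalizing each layer is equivalent to scale normalizing only the last layer. That is, for all affine transformations ${\bm{A}}_1,\dots,{\bm{A}}_L$, the network that scale normalizes after every layer computes the same function as the network that scale normalizes only after the last layer, provided that all the intermediate hidden vectors after activation are non-zero.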
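The proposition can be checked numerically. The sketch below is an illustration under stated assumptions rather than the paper's construction: it takes scale normalization to be division by the L2 norm, uses ReLU as the scale-equivariant activation, and uses bias-free linear maps so that each layer is positively homogeneous. The function names and layer sizes are hypothetical.

```python
import numpy as np

def scale_normalize(h):
    """Divide a non-zero vector by its L2 norm."""
    return h / np.linalg.norm(h)

def relu(h):
    """ReLU is scale equivariant: relu(c * h) == c * relu(h) for c > 0."""
    return np.maximum(h, 0.0)

def forward(x, weights, normalize_every_layer):
    """Run a bias-free ReLU network, normalizing every layer or only the last."""
    h = x
    for W in weights:
        h = relu(W @ h)
        if normalize_every_layer:
            h = scale_normalize(h)
    # With per-layer normalization the output already has unit norm.
    return h if normalize_every_layer else scale_normalize(h)

rng = np.random.default_rng(0)
width, depth = 32, 4
weights = [rng.normal(size=(width, width)) for _ in range(depth)]
x = rng.normal(size=width)

out_all = forward(x, weights, normalize_every_layer=True)
out_last = forward(x, weights, normalize_every_layer=False)
max_gap = np.max(np.abs(out_all - out_last))
print(f"max abs difference: {max_gap:.2e}")
```

The two outputs agree up to floating-point error because a positive rescaling of a hidden vector passes through each bias-free linear map and ReLU unchanged in direction, so only the final normalization affects the result. Adding non-zero biases breaks this argument, which is why the sketch omits them.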

Figures (9)

  • Figure 1: Visualization of label shift.
  • Figure 2: Local overfitting in the one-class setting on CIFAR-10. The client only has examples from class 0. The blue lines show the average global performance. (left) the test accuracies of the pre-trained model before local training; (middle) after 5 steps of local training with a vanilla model; (right) after 5 steps of local training with FN. Best viewed in color.
  • Figure 3: Local training with only samples from one class. (left): vector norms of each class embedding. (middle left): the norms of different feature vectors. We randomly choose 20 images from the dataset. (middle right): singular values of features of a local overfitted model; (right): singular values of normalized features learned from FedAvg and FedFN.
  • Figure 4: Comparing test accuracies of centralized training with FedAvg of models with different normalizations in one-class label shift in CIFAR-10. The $x$-axis denotes how much data is fed into an algorithm, measured by batches.
  • Figure 5: Effect of data heterogeneity on the performance of normalization. $\beta$ represents the parameter in the Dirichlet distribution that is used to sample client label distributions.
  • ...and 4 more figures
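
The Dirichlet-based heterogeneity mentioned in the caption of Figure 5 can be sketched as follows. This is a minimal illustration that assumes a symmetric Dirichlet distribution with concentration $\beta$ over the classes, drawn independently for each client; the paper's exact partitioning procedure may differ, and the function name and client counts are hypothetical.

```python
import numpy as np

def sample_client_label_distributions(num_clients, num_classes, beta, seed=0):
    """Draw one label distribution per client from a symmetric Dirichlet(beta)."""
    rng = np.random.default_rng(seed)
    alpha = np.full(num_classes, beta)
    return rng.dirichlet(alpha, size=num_clients)  # shape: (clients, classes)

# Small beta: each client's mass concentrates on a few classes (strong skew).
skewed = sample_client_label_distributions(10, 10, beta=0.1)
# Large beta: each client's label distribution is close to uniform.
balanced = sample_client_label_distributions(10, 10, beta=100.0)

print("mean largest class share, beta=0.1:", skewed.max(axis=1).mean())
print("mean largest class share, beta=100:", balanced.max(axis=1).mean())
```

Each row is a probability vector over classes. Smaller values of $\beta$ give more label-skewed clients, approaching the setting where each client sees very few classes.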

Theorems & Definitions (20)

  • Proposition 1: reduced feature normalization
  • Proposition 2: reduced layer normalization
  • Proposition 3: expressive power
  • Theorem 1: divergent norms
  • Lemma 1: scale equivariant activation
  • proof
  • Proposition 3: reduced feature normalization
  • Proposition 1': reduced feature normalization
  • proof
  • Proposition 3: reduced layer normalization
  • ...and 10 more