Federated Learning with Matched Averaging

Hongyi Wang, Mikhail Yurochkin, Yuekai Sun, Dimitris Papailiopoulos, Yasaman Khazaeni

TL;DR

Federated learning trains models on-device without centralizing the data, but naive coordinate-wise weight averaging can fail because neural networks are permutation invariant and client data is heterogeneous. FedMA introduces a layer-wise, permutation-aware matching-and-averaging scheme for CNNs and LSTMs. It uses a BBP-MAP-based approach with Hungarian matching to align hidden elements before averaging and to adapt the size of the global model. Empirically, FedMA outperforms FedAvg and FedProx on deep CNNs and LSTMs while reducing communication overhead, and it remains robust to data biases and increasing client heterogeneity. The method also yields more interpretable global filters, which is a practical advantage for federated learning in real-world, heterogeneous settings.
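The permutation-invariance problem mentioned above can be seen in a toy example. The following is a minimal sketch, not code from the paper: the two-layer `tanh` network, the weights, and the helper names (`forward`, `permute_hidden`, `average`) are all hypothetical and chosen only for illustration. It shows that reordering hidden units leaves the network's function unchanged, while averaging the two equivalent weight sets coordinate by coordinate does not.

```python
import math

def forward(x, W1, W2):
    """Toy two-layer network: y = W2 · tanh(W1 · x), weights as nested lists."""
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in W1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in W2]

def permute_hidden(W1, W2, perm):
    """Reorder hidden units: rows of W1 and columns of W2 move together."""
    W1p = [W1[j] for j in perm]
    W2p = [[row[j] for j in perm] for row in W2]
    return W1p, W2p

def average(A, B):
    """Coordinate-wise average of two weight matrices of the same shape."""
    return [[(a + b) / 2 for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

# Hypothetical weights: 2 inputs, 2 hidden units, 1 output.
W1 = [[1.0, 0.0], [0.0, 1.0]]
W2 = [[1.0, -1.0]]
x = [0.5, -0.3]

# A second "client" holds the same function with its hidden units swapped.
W1p, W2p = permute_hidden(W1, W2, [1, 0])

y_original = forward(x, W1, W2)      # output of the original network
y_permuted = forward(x, W1p, W2p)    # same output: the permutation is harmless
y_naive_avg = forward(x, average(W1, W1p), average(W2, W2p))  # different output

print(y_original, y_permuted, y_naive_avg)
```

In this example the averaged output layer becomes all zeros, so the averaged network returns 0 for every input even though both clients implement exactly the same function.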

Abstract

Federated learning allows edge devices to collaboratively learn a shared model while keeping the training data on device, decoupling the ability to do model training from the need to store the data in the cloud. We propose the Federated Matched Averaging (FedMA) algorithm, designed for federated learning of modern neural network architectures, e.g., convolutional neural networks (CNNs) and LSTMs. FedMA constructs the shared global model in a layer-wise manner by matching and averaging hidden elements (i.e., channels for convolution layers, hidden states for LSTMs, and neurons for fully connected layers) with similar feature extraction signatures. Our experiments indicate that FedMA not only outperforms popular state-of-the-art federated learning algorithms on deep CNN and LSTM architectures trained on real-world datasets, but also reduces the overall communication burden.
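The idea of matching hidden elements before averaging them can be sketched for a single fully connected layer and two clients. This is an illustrative simplification, not the FedMA procedure itself: it assumes both clients have the same number of neurons, uses squared Euclidean distance between weight vectors as the similarity measure, and finds the alignment by brute force over all permutations. A practical implementation would solve the assignment with the Hungarian algorithm (for example, SciPy's `linear_sum_assignment`). All names and weights below are hypothetical.

```python
from itertools import permutations

def sq_dist(u, v):
    """Squared Euclidean distance between two neuron weight vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def best_permutation(ref, other):
    """Return perm so that other[perm[i]] is matched to ref[i] at minimal total cost.

    Brute force over all permutations: fine for a toy layer, infeasible for
    real layers, where an assignment solver would be used instead.
    """
    n = len(ref)
    return min(
        permutations(range(n)),
        key=lambda p: sum(sq_dist(ref[i], other[p[i]]) for i in range(n)),
    )

def matched_average(ref, other):
    """Align the neurons of `other` to `ref`, then average coordinate-wise."""
    perm = best_permutation(ref, other)
    aligned = [other[j] for j in perm]
    avg = [[(a + b) / 2 for a, b in zip(r, o)] for r, o in zip(ref, aligned)]
    return avg, perm

# Hypothetical client layers: each row holds one neuron's incoming weights.
client_a = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
# Client B learned similar neurons, but with noise and in a different order.
client_b = [[-0.9, -1.1], [1.1, 0.1], [0.1, 0.9]]

global_layer, perm = matched_average(client_a, client_b)
print(perm)          # which neuron of client B is paired with each neuron of client A
print(global_layer)  # averaged weights after alignment
```

In a multi-layer network the permutation found for one layer would also have to be applied to the input dimension of the next layer, which is one reason the global model is built layer by layer.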

Paper Structure

This paper contains 28 sections, 6 equations, 6 figures, 7 tables, 1 algorithm.

Figures (6)

  • Figure 1: Comparison of federated learning methods under a limited number of communications, for LeNet trained on MNIST, VGG-9 trained on CIFAR-10, and an LSTM trained on the Shakespeare dataset, over (a) a homogeneous data partition and (b) a heterogeneous data partition.
  • Figure 2: Convergence rates of various methods in two federated learning scenarios: training VGG-9 on CIFAR-10 with $J=16$ clients and training LSTM on Shakespeare dataset with $J=66$ clients.
  • Figure 3: The effect of the number of local training epochs on various methods.
  • Figure 4: Performance on skewed CIFAR-10 dataset.
  • Figure 5: Representations generated by the first convolution layers of the locally trained models, the FedMA global model, and the FedAvg global model.
  • ...and 1 more figure