Why Go Full? Elevating Federated Learning Through Partial Network Updates

Haolin Wang, Xuefeng Liu, Jianwei Niu, Wenkai Guo, Shaojie Tang

TL;DR

FedPart restricts model updates to a single layer or a few layers in each communication round. It significantly surpasses conventional full network update strategies in convergence speed and accuracy, while also reducing communication and computational overheads.

Abstract

Federated learning is a distributed machine learning paradigm designed to protect user data privacy, which has been successfully implemented across various scenarios. In traditional federated learning, the entire parameter set of local models is updated and averaged in each training round. Although this full network update method maximizes knowledge acquisition and sharing for each model layer, it prevents the layers of the global model from cooperating effectively to complete the tasks of each client, a challenge we refer to as layer mismatch. This mismatch problem recurs after every parameter averaging, consequently slowing down model convergence and degrading overall performance. To address the layer mismatch issue, we introduce the FedPart method, which restricts model updates to either a single layer or a few layers during each communication round. Furthermore, to maintain the efficiency of knowledge acquisition and sharing, we develop several strategies to select trainable layers in each round, including sequential updating and multi-round cycle training. Through both theoretical analysis and experiments, our findings demonstrate that the FedPart method significantly surpasses conventional full network update strategies in terms of convergence speed and accuracy, while also reducing communication and computational overheads.
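To make the idea of partial network updates concrete, here is a minimal sketch of one possible reading of the abstract: a sequential, cyclic schedule that selects one layer per round, followed by server-side averaging restricted to the selected layers. The names `trainable_layers`, `partial_average`, and `rounds_per_layer`, the dict-of-lists model representation, and the uniform (unweighted) averaging are all assumptions made for illustration. They are not taken from the paper, whose actual selection strategies and aggregation weights may differ.

```python
from typing import Dict, List

# Hypothetical model representation: layer index -> flat list of parameters.
Model = Dict[int, List[float]]


def trainable_layers(round_idx: int, num_layers: int,
                     rounds_per_layer: int = 1) -> List[int]:
    """Pick the layer to train in this round (illustrative schedule).

    Layers are visited in order, each for `rounds_per_layer` consecutive
    rounds, and the sequence restarts after the last layer.
    """
    return [(round_idx // rounds_per_layer) % num_layers]


def partial_average(global_model: Model, client_models: List[Model],
                    layers: List[int]) -> Model:
    """Average only the selected layers; keep all other layers unchanged.

    Uniform averaging is an assumption here; a real system might weight
    clients by their local dataset size.
    """
    new_model = {k: list(v) for k, v in global_model.items()}
    n = len(client_models)
    for layer in layers:
        size = len(global_model[layer])
        new_model[layer] = [
            sum(client[layer][i] for client in client_models) / n
            for i in range(size)
        ]
    return new_model


# Toy usage: two clients, a two-layer model, only layer 0 is aggregated.
global_model = {0: [0.0, 0.0], 1: [1.0]}
clients = [{0: [1.0, 3.0], 1: [5.0]}, {0: [3.0, 5.0], 1: [7.0]}]
selected = trainable_layers(round_idx=0, num_layers=2)
updated = partial_average(global_model, clients, selected)
```

In this sketch only the selected layer's parameters would need to be sent to the server, which is where the reduction in communication comes from; the frozen layers stay identical across all clients.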

Paper Structure

This paper contains 23 sections, 24 equations, 8 figures, 13 tables.

Figures (8)

  • Figure 1: Update step sizes for each iteration. The experiment uses the ResNet-8 model with 20,000 CIFAR-100 images distributed in an i.i.d. manner across 40 clients.
  • Figure 2: Mechanism for layer mismatch in FedAvg and FedPart.
  • Figure 3: Strategy for selecting trainable layers.
  • Figure 4: Model architecture and layer partitioning for our ResNet-8 and ResNet-18 models.
  • Figure 5: Model architecture and layer partitioning for language transformer.
  • ...and 3 more figures