Hierarchical Federated Learning with Multi-Timescale Gradient Correction

Wenzhi Fang, Dong-Jun Han, Evan Chen, Shiqiang Wang, Christopher G. Brinton

TL;DR

The paper tackles multi-level model drift in hierarchical federated learning (HFL) by introducing MTGC, a two-timescale gradient-correction framework. MTGC applies two corrections at once: a client-group correction that aligns each client's gradient with its group gradient, and a group-global correction that aligns each group gradient with the global gradient. The former is updated after every group aggregation and the latter after every global aggregation. The authors establish convergence guarantees for non-convex objectives, showing linear speedup in the number of local updates, group aggregations, and clients, with a bound that does not depend on the degree of data heterogeneity. Experiments on EMNIST, Fashion-MNIST, CIFAR-10, and CIFAR-100 show that MTGC outperforms hierarchical FedAvg variants and gradient-correction baselines across a range of non-i.i.d. settings and scales to multi-level hierarchies, making it relevant for large-scale, privacy-preserving distributed learning.

Abstract

While traditional federated learning (FL) typically focuses on a star topology where clients are directly connected to a central server, real-world distributed systems often exhibit hierarchical architectures. Hierarchical FL (HFL) has emerged as a promising solution to bridge this gap, leveraging aggregation points at multiple levels of the system. However, existing algorithms for HFL encounter challenges in dealing with multi-timescale model drift, i.e., model drift occurring across hierarchical levels of data heterogeneity. In this paper, we propose a multi-timescale gradient correction (MTGC) methodology to resolve this issue. Our key idea is to introduce distinct control variables to (i) correct the client gradient towards the group gradient, i.e., to reduce client model drift caused by local updates based on individual datasets, and (ii) correct the group gradient towards the global gradient, i.e., to reduce group model drift caused by FL over clients within the group. We analytically characterize the convergence behavior of MTGC under general non-convex settings, overcoming challenges associated with couplings between correction terms. We show that our convergence bound is immune to the extent of data heterogeneity, confirming the stability of the proposed algorithm against multi-level non-i.i.d. data. Through extensive experiments on various datasets and models, we validate the effectiveness of MTGC in diverse HFL settings. The code for this project is available at https://github.com/wenzhifang/MTGC.
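
The two control variables described in the abstract can be illustrated on a toy problem. The sketch below is a minimal, hypothetical simulation, not the authors' implementation (that lives in the linked repository). It assumes four 1-D quadratic clients $f_i(x) = \frac{1}{2} a_i (x - b_i)^2$ split into two groups, exact gradients instead of stochastic ones, a global objective that averages the group objectives with equal weight, and corrections computed from gradients at the freshly aggregated model (a SCAFFOLD-style simplification chosen here). Each local step uses $\nabla f_i(x) + z_i + y_j$, where the client-group term $z_i$ is recomputed after every group aggregation and the group-global term $y_j$ after every global aggregation. The names `GROUPS`, `run`, `T`, `E`, `H`, and `lr` are illustrative.

```python
# Hypothetical clients f_i(x) = 0.5 * a_i * (x - b_i)**2, stored as (a_i, b_i) per group
GROUPS = [
    [(1.0, 0.0), (4.0, 1.0)],   # group 0
    [(2.0, -1.0), (8.0, 3.0)],  # group 1
]


def grad(client, x):
    a, b = client
    return a * (x - b)


def group_grad(group, x):
    # gradient of a group objective: average over the group's clients
    return sum(grad(c, x) for c in group) / len(group)


def global_grad(x):
    # gradient of the global objective: equal-weight average over groups
    return sum(group_grad(g, x) for g in GROUPS) / len(GROUPS)


def global_optimum():
    # closed-form minimizer of the equal-weight global objective
    num = sum(sum(a * b for a, b in g) / len(g) for g in GROUPS)
    den = sum(sum(a for a, _ in g) / len(g) for g in GROUPS)
    return num / den


def run(T=30, E=5, H=50, lr=0.05, correct=True, x=0.0):
    """T global rounds, each with E group rounds of H local steps per client."""
    for _ in range(T):
        # group-global correction y_j, recomputed after each global aggregation
        y = [global_grad(x) - group_grad(g, x) if correct else 0.0 for g in GROUPS]
        group_models = []
        for j, group in enumerate(GROUPS):
            xj = x
            for _ in range(E):
                client_models = []
                for c in group:
                    # client-group correction z_i, recomputed at the current group model
                    z = group_grad(group, xj) - grad(c, xj) if correct else 0.0
                    xi = xj
                    for _ in range(H):
                        xi -= lr * (grad(c, xi) + z + y[j])
                    client_models.append(xi)
                xj = sum(client_models) / len(client_models)  # group aggregation
            group_models.append(xj)
        x = sum(group_models) / len(group_models)  # global aggregation
    return x


if __name__ == "__main__":
    print("x* (global minimizer):", global_optimum())
    print("with MTGC-style corrections:", run(correct=True))
    print("without corrections (HFedAvg):", run(correct=False))
```

With these toy values the corrected run converges to the global minimizer $x^* = 26/15$, while the uncorrected run (hierarchical FedAvg) settles at a different point, because during the $H$ local steps each client drifts toward its own optimum $b_i$.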

Paper Structure (33 sections, 14 theorems, 86 equations, 11 figures, 1 table, 2 algorithms)

Key Result

Theorem 4.1

Suppose the smoothness assumption and the stochastic-gradient assumption hold and the learning rate satisfies $\gamma \leq \frac{1}{40 E H L}$. Then the iterates $\{ \hat{\bm x}^{t,e}\}$ obtained by the MTGC algorithm satisfy the convergence bound given in the paper, where $\tilde{N} = \left(\frac{1}{N^2} \! \sum_{j=1}^{N} \! \frac{1}{n_j} \right)^{-1}$ and $f^*$ is a lower bound of $f(\bm x)$, i.e., $f(\bm x) \geq f^*$ for all $\bm x$.
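
To make the two constants in Theorem 4.1 concrete, the sketch below evaluates $\tilde{N}$ for a few hypothetical group sizes $n_j$ and the step-size cap $\frac{1}{40 E H L}$. Since $\tilde{N} = N^2 / \sum_{j} 1/n_j$ is $N$ times the harmonic mean of the group sizes, it equals the total number of clients when all groups have the same size and is strictly smaller otherwise. The helper names, the group sizes, and the smoothness constant $L = 1$ are assumptions; $E = 30$ and $H = 20$ are borrowed from the Figure 5 caption, and reading $E$ as group aggregations per global round and $H$ as local iterations per group round is an interpretation suggested by the index $\hat{\bm x}^{t,e}$.

```python
from fractions import Fraction


def effective_num_clients(group_sizes):
    """N_tilde = ((1 / N^2) * sum_j 1 / n_j)^(-1), computed exactly with Fraction."""
    N = len(group_sizes)
    inv_sum = sum(Fraction(1, n) for n in group_sizes)
    return 1 / (inv_sum / N**2)


def max_step_size(E, H, L):
    """Largest learning rate allowed by gamma <= 1 / (40 * E * H * L)."""
    return 1.0 / (40 * E * H * L)


if __name__ == "__main__":
    print(effective_num_clients([10, 10, 10]))  # equal groups: prints 30, the total client count
    print(effective_num_clients([1, 9]))        # unequal groups: prints 18/5, below the 10 clients
    print(max_step_size(E=30, H=20, L=1.0))     # hypothetical L = 1, i.e. 1/24000
```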

Figures (11)

  • Figure 1: Illustration of multi-timescale gradient correction (MTGC) for multi-level non-i.i.d. data in HFL.
  • Figure 2: Visualization of the local update process using multi-timescale gradient correction (MTGC) with $4$ clients and $2$ groups. (a) Without any gradient correction (e.g., hierarchical FedAvg), each client model moves towards its respective optimal point, denoted by $\bm x_i^{*}$. (b) When only client-group correction term $\bm z_i$ is applied, the model of client $i\in \mathcal{C}_j$ moves towards the group optimum $\bar{\bm x}_j^{*}$. (c) In MTGC, the gradient of client $i \in \mathcal{C}_j$ is adjusted by both the client-group correction term $\bm{z}_i$ and the group-global correction variable $\bm{y}_j$, assisting each client model to converge towards the global optimum $\bm{x}^{*}$ during local iterations.
  • Figure 3: Comparison with FL baselines. In this experiment, popular FL algorithms are extended to the HFL setup for comparison with MTGC. We consider four datasets in the group non-i.i.d. & client non-i.i.d. setting. Experiments are conducted over $3$ random trials. We see that MTGC obtains the best testing accuracy in each case, validating our multi-level approach for correcting multi-timescale model drifts.
  • Figure 4: Comparison with gradient correction baselines. Three different data distribution scenarios are considered. We see that the local correction method is effective for handling client non-i.i.d. within each group (top row), while the group correction method is effective for handling non-i.i.d. across groups (middle row). MTGC obtains the most stable performance (all rows) by combining multiple correction levels.
  • Figure 5: Comparison of testing accuracy versus global communication round across different system parameters in the group non-i.i.d. & client non-i.i.d. setting. $E$ and $H$ are set to $30$ and $20$, respectively.
  • ...and 6 more figures
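
The three panels of Figure 2 can be checked on a toy problem by asking where each client's (possibly corrected) gradient vanishes. The sketch below is an illustration under assumptions, not the paper's code: it uses hypothetical 1-D quadratic clients, exact gradients, an equal-weight global objective, and evaluates both correction terms at the same point $x$, so $\nabla f_i(x) + \bm z_i$ collapses to the group gradient and $\nabla f_i(x) + \bm z_i + \bm y_j$ to the global gradient. As a result, the client-group correction alone makes the group optimum $\bar{\bm x}_j^{*}$ stationary for every client in group $j$ (panel b), and adding the group-global correction makes the global optimum $\bm x^{*}$ stationary for all clients (panel c). The function names are illustrative.

```python
# Hypothetical clients f_i(x) = 0.5 * a_i * (x - b_i)**2, stored as (a_i, b_i) per group
GROUPS = [[(1.0, 0.0), (4.0, 1.0)], [(2.0, -1.0), (8.0, 3.0)]]


def grad(client, x):
    a, b = client
    return a * (x - b)


def group_grad(j, x):
    return sum(grad(c, x) for c in GROUPS[j]) / len(GROUPS[j])


def global_grad(x):
    # global objective: equal-weight average of the group objectives
    return sum(group_grad(j, x) for j in range(len(GROUPS))) / len(GROUPS)


def group_optimum(j):
    return sum(a * b for a, b in GROUPS[j]) / sum(a for a, _ in GROUPS[j])


def global_optimum():
    num = sum(sum(a * b for a, b in g) / len(g) for g in GROUPS)
    den = sum(sum(a for a, _ in g) / len(g) for g in GROUPS)
    return num / den


def corrected_grad(j, client, x, use_z, use_y):
    """Client gradient plus optional client-group (z) and group-global (y) corrections."""
    z = group_grad(j, x) - grad(client, x) if use_z else 0.0
    y = global_grad(x) - group_grad(j, x) if use_y else 0.0
    return grad(client, x) + z + y


if __name__ == "__main__":
    x_star = global_optimum()
    for j, group in enumerate(GROUPS):
        for client in group:
            print(
                j, client,
                corrected_grad(j, client, x_star, False, False),           # (a) nonzero: pulls toward b_i
                corrected_grad(j, client, group_optimum(j), True, False),  # (b) ~0 at the group optimum
                corrected_grad(j, client, x_star, True, True),             # (c) ~0 at the global optimum
            )
```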

Theorems & Definitions (14)

  • Theorem 4.1
  • Corollary 4.1
  • Lemma F.1.1
  • Lemma F.1.2
  • Lemma F.1.3
  • Lemma F.1.4
  • Lemma F.1.5
  • Lemma F.2.1
  • Lemma F.2.2
  • Lemma F.2.3
  • ...and 4 more