Understanding and Improving Model Averaging in Federated Learning on Heterogeneous Data

Tailin Zhou, Zehong Lin, Jun Zhang, Danny H. K. Tsang

TL;DR

Model averaging performs strongly in federated learning (FL) on heterogeneous data, but why it works has not been well understood. The authors visualize loss landscapes to show that the global model typically lies within a basin enclosed by the client models but can deviate from the basin's center, and they decompose the global loss into five factors (TrainBias, HeterBias, Var, Cov, Locality) to explain this behavior. They relate model averaging in FL (FMA) to a weighted ensemble of client model outputs (WENS) and propose Iterative Moving Averaging (IMA), combined with limited client exploration, to keep the global model near the basin's center during late training. Across diverse heterogeneous setups on benchmark datasets, adding IMA to existing FL methods yields consistent gains in accuracy and training speed and reduces communication overhead, offering a geometry-informed, practical enhancement to FL in non-IID settings.

Abstract

Model averaging is a widely adopted technique in federated learning (FL) that aggregates multiple client models to obtain a global model. Remarkably, model averaging in FL yields a superior global model, even when client models are trained with non-convex objective functions and on heterogeneous local datasets. However, the rationale behind its success remains poorly understood. To shed light on this issue, we first visualize the loss landscape of FL over client and global models to illustrate their geometric properties. The visualization shows that the client models encompass the global model within a common basin, and interestingly, the global model may deviate from the basin's center while still outperforming the client models. To gain further insights into model averaging in FL, we decompose the expected loss of the global model into five factors related to the client models. Specifically, our analysis reveals that the global model loss after early training mainly arises from (i) the client model's loss on non-overlapping data between client datasets and the global dataset and (ii) the maximum distance between the global and client models. Based on the findings from our loss landscape visualization and loss decomposition, we propose utilizing iterative moving averaging (IMA) on the global model at the late training phase to reduce its deviation from the expected minimum, while constraining client exploration to limit the maximum distance between the global and client models. Our experiments demonstrate that incorporating IMA into existing FL methods significantly improves their accuracy and training speed on various heterogeneous data setups of benchmark datasets. Code is available at https://github.com/TailinZhou/FedIMA.
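The abstract describes IMA only at a high level. The Python sketch below illustrates the server-side idea under our own assumptions: client models are flattened parameter lists, FMA is a dataset-size-weighted average (FedAvg-style), and IMA replaces the global model with the uniform average of the most recent `window` global models once training reaches a hypothetical `start_round`. The names `IterativeMovingAverage`, `window`, and `start_round` are illustrative, not taken from the paper or the FedIMA repository, and the sketch omits the client-side constraint on exploration.

```python
from collections import deque
from typing import Deque, List, Sequence

Vector = List[float]


def weighted_average(models: Sequence[Vector], weights: Sequence[float]) -> Vector:
    """Coordinate-wise weighted average of flattened parameter vectors."""
    total = float(sum(weights))
    dim = len(models[0])
    return [sum(w * m[i] for m, w in zip(models, weights)) / total for i in range(dim)]


def fma(client_models: Sequence[Vector], client_sizes: Sequence[int]) -> Vector:
    """FedAvg-style model averaging with weights proportional to dataset sizes n_k."""
    return weighted_average(client_models, client_sizes)


class IterativeMovingAverage:
    """Hypothetical server-side helper: from `start_round` on, returns the uniform
    average of the last `window` global models instead of only the latest one."""

    def __init__(self, window: int, start_round: int) -> None:
        self.start_round = start_round
        # A bounded deque silently drops the oldest global model once full.
        self.history: Deque[Vector] = deque(maxlen=window)

    def update(self, round_idx: int, global_model: Vector) -> Vector:
        self.history.append(global_model)
        if round_idx < self.start_round:
            return global_model  # early phase: keep the plain FMA result
        recent = list(self.history)
        return weighted_average(recent, [1.0] * len(recent))


# Usage with made-up 1D "models" so the averaging is easy to follow.
ima = IterativeMovingAverage(window=3, start_round=2)
for r in range(5):
    print(r, ima.update(r, [float(r)]))
```

In this sketch `ima.update` would be called once per round on the FMA output, and its return value broadcast to clients. Whether the paper averages FMA outputs or previous IMA outputs, and how it chooses the window and start round, should be checked against the released code.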

Paper Structure

This paper contains 44 sections, 5 theorems, 17 equations, 10 figures, 7 tables, and 1 algorithm.

Key Result

Lemma 1

(FMA and WENS; see proof in Appendix.) Given $K$ models $\{\mathbf{w}_k\}_{k=1}^K$ and $n_i/n_j \neq \infty$ when $i\neq j$, we denote $\Delta_k = \|\mathbf{w}_{k} - \mathbf{w}_{\rm FMA}\|$ and $\Delta = \max_{k=1}^{K} \Delta_k$. Then, we have: where WENS on the $K$ models performs weighted averaging of these models' outputs given the same input, represented as $f_{\rm WENS}(x)= \su
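To make the FMA/WENS distinction concrete, the sketch below compares, on a hypothetical one-unit model $f(\mathbf{w}; x) = \tanh(\mathbf{w}^\top x)$, the output of the averaged parameters (FMA) with the weighted average of the individual outputs (WENS), using weights $n_k/n$ from assumed client dataset sizes. It also computes $\Delta = \max_k \|\mathbf{w}_k - \mathbf{w}_{\rm FMA}\|$. For a linear model the two coincide exactly; for a smooth nonlinear model the first-order Taylor terms cancel because $\sum_k (n_k/n)(\mathbf{w}_k - \mathbf{w}_{\rm FMA}) = 0$, so the gap should shrink much faster than $\Delta$ as the client models contract around $\mathbf{w}_{\rm FMA}$. This is a standard argument offered for intuition, not a reconstruction of Lemma 1's truncated statement; the centers, offsets, and sizes are made up.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Model = Callable[[Vector, Vector], float]  # f(w, x) -> scalar output


def fma(models: Sequence[Vector], sizes: Sequence[int]) -> Vector:
    """Parameter-space average with weights n_k / n."""
    n = float(sum(sizes))
    return [sum(s * m[i] for m, s in zip(models, sizes)) / n for i in range(len(models[0]))]


def wens(f: Model, models: Sequence[Vector], sizes: Sequence[int], x: Vector) -> float:
    """Output-space (ensemble) average with weights n_k / n."""
    n = float(sum(sizes))
    return sum(s * f(m, x) for m, s in zip(models, sizes)) / n


def linear(w: Vector, x: Vector) -> float:
    return sum(wi * xi for wi, xi in zip(w, x))


def tanh_unit(w: Vector, x: Vector) -> float:
    return math.tanh(linear(w, x))


def fma_wens_gap(f: Model, models: Sequence[Vector], sizes: Sequence[int],
                 x: Vector) -> Tuple[float, float]:
    """Returns (|f(w_FMA; x) - f_WENS(x)|, Delta = max_k ||w_k - w_FMA||)."""
    w_fma = fma(models, sizes)
    gap = abs(f(w_fma, x) - wens(f, models, sizes, x))
    delta = max(math.dist(m, w_fma) for m in models)
    return gap, delta


# Hypothetical setup: three clients spread around a common center by `scale`.
center = [0.5, -0.2]
offsets = [[1.0, 0.0], [-0.5, 1.0], [0.0, -1.0]]
sizes = [1, 1, 2]  # assumed client dataset sizes n_k
x = [1.0, 2.0]


def clients_at(scale: float) -> List[Vector]:
    return [[c + scale * d for c, d in zip(center, off)] for off in offsets]


for scale in (1.0, 0.1):
    gap, delta = fma_wens_gap(tanh_unit, clients_at(scale), sizes, x)
    print(f"scale={scale}: Delta={delta:.4f}, |FMA - WENS|={gap:.5f}")
```

Shrinking the client spread tenfold shrinks $\Delta$ tenfold here, while the FMA/WENS gap shrinks by far more, which matches the intuition that keeping client models close to the global model keeps FMA behaving like an output ensemble.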

Figures (10)

  • Figure 1: Visualization of the loss (top row) and classification error (bottom row) landscapes on the CIFAR-10 test dataset, showing three client models from the early stage (first column), middle stage (second column), and final stage (third column), and three global models from the final three rounds (fourth column). The black triangles mark the locations of the three models in the plane, and the white cross marks the location of their average. Each loss/error landscape forms a basin: as FL training proceeds, client models reach the basin's wall and the global model approaches the basin's bottom. FMA moves the global model toward the basin's bottom by averaging client models on the basin's wall, while heterogeneous data pushes the global model away from the basin's center.
  • Figure 2: Train and heterogeneous biases w.r.t. rounds (x-axis).
  • Figure 3: Locality (L2 distance) w.r.t. rounds (x-axis).
  • Figure 4: Test error w.r.t. model number and model similarity.
  • Figure 5: (a) A toy example of 1D loss landscape visualization to show the motivation of IMA; (b) interpolation among global models to validate the effect of interpolation on alleviating global models' deviation from the basin's center; (c) interpolation between the IMA and global models to indicate the flatness of the IMA model near the basin's center.
  • ...and 5 more figures
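Figure 5 motivates IMA with a 1D toy landscape and with interpolation between models. A minimal sketch of both ideas, assuming a hypothetical convex 1D basin $L(w) = w^2$ centered at $w = 0$ and made-up global-model positions that scatter around the center across rounds: averaging the positions lands closer to the center than any single one, and the loss along the straight line between two models on opposite walls dips below both endpoints. Real FL landscapes are high-dimensional and non-convex, so this only illustrates the geometric intuition, not the paper's experiments.

```python
from typing import Callable, List, Sequence


def toy_loss(w: float) -> float:
    """Hypothetical 1D basin whose center (minimum) is at w = 0."""
    return w * w


def average(models: Sequence[float]) -> float:
    """Uniform average of 1D model positions (e.g., recent global models)."""
    return sum(models) / len(models)


def interpolate(a: float, b: float, alpha: float) -> float:
    """Point at fraction alpha along the segment from a (alpha=0) to b (alpha=1)."""
    return (1.0 - alpha) * a + alpha * b


def loss_along_path(loss: Callable[[float], float], a: float, b: float,
                    steps: int) -> List[float]:
    """Loss at steps + 1 evenly spaced points on the segment between a and b."""
    return [loss(interpolate(a, b, i / steps)) for i in range(steps + 1)]


# Made-up global models from three late rounds, scattered around the center.
recent_globals = [0.3, -0.2, 0.25]
averaged = average(recent_globals)
print("individual losses:", [toy_loss(g) for g in recent_globals])
print("loss of averaged model:", toy_loss(averaged))
print("loss along path:", loss_along_path(toy_loss, 0.3, -0.2, 10))
```

Because the toy basin is convex, the averaged model's loss cannot exceed the mean of the individual losses; in the non-convex setting the paper relies on the empirical observation that the models share a common basin.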

Theorems & Definitions (5)

  • Lemma 1
  • Theorem 1
  • Theorem 2
  • Corollary 1
  • Corollary 2