FedAdamW: A Communication-Efficient Optimizer with Convergence and Generalization Guarantees for Federated Large Models

Junkang Liu, Fanhua Shang, Kewen Zhu, Hongying Liu, Yuanyuan Liu, Jin Liu

TL;DR

FedAdamW introduces a communication-efficient federated optimizer for training large models by combining a local correction mechanism, decoupled weight decay, and a block-wise second-moment aggregation strategy. The approach mitigates client drift, reduces variance in adaptive statistics, and leverages Hessian-informed partitioning to limit communication. Theoretical results establish a linear speedup convergence rate for nonconvex objectives without gradient heterogeneity assumptions and a PAC-Bayesian generalization bound that improves with decoupled weight decay. Empirically, FedAdamW consistently outperforms strong FL baselines on CNNs, Vision Transformers, Swin, and RoBERTa-LoRA tasks, demonstrating faster convergence and better test accuracy under non-iid data and in communication-constrained settings.
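
The block-wise second-moment aggregation mentioned above can be sketched in pure Python. This is a minimal sketch under stated assumptions, not the FedAdamW implementation. Each client's second-moment vector `v` is a flat list. The blocks are hypothetical contiguous index ranges that stand in for the Hessian-informed partition. The server takes an unweighted mean over participating clients. The function names `block_means`, `server_aggregate`, and `expand_blocks` are illustrative.

```python
from typing import List, Sequence, Tuple

Block = Tuple[int, int]  # hypothetical half-open index range [start, end) of one block


def block_means(v: Sequence[float], blocks: Sequence[Block]) -> List[float]:
    """Client side: compress a second-moment vector to one mean per block."""
    return [sum(v[s:e]) / (e - s) for s, e in blocks]


def server_aggregate(client_stats: Sequence[Sequence[float]]) -> List[float]:
    """Server side: unweighted average of each block mean across clients."""
    n = len(client_stats)
    return [sum(col) / n for col in zip(*client_stats)]


def expand_blocks(means: Sequence[float], blocks: Sequence[Block], dim: int) -> List[float]:
    """Client side: reinitialize v by broadcasting each block mean over its block."""
    v = [0.0] * dim
    for mean, (s, e) in zip(means, blocks):
        for j in range(s, e):
            v[j] = mean
    return v


if __name__ == "__main__":
    blocks = [(0, 2), (2, 5)]  # two blocks over a 5-dimensional parameter vector
    v_client_a = [1.0, 3.0, 0.0, 0.0, 3.0]
    v_client_b = [3.0, 5.0, 6.0, 6.0, 6.0]
    stats = [block_means(v, blocks) for v in (v_client_a, v_client_b)]
    global_means = server_aggregate(stats)
    print(global_means)
    print(expand_blocks(global_means, blocks, dim=5))
```

In each direction, only one float per block is transmitted, and this is where the communication saving comes from. Broadcasting each block mean back over its block is one plausible way to reinitialize `v` from the aggregate.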

Abstract

AdamW has become one of the most effective optimizers for training large-scale models, and we have also observed its effectiveness in federated learning (FL). However, applying AdamW directly in FL poses significant challenges: (1) due to data heterogeneity, AdamW often yields high variance in the second-moment estimate $\boldsymbol{v}$; (2) local overfitting under AdamW may cause client drift; and (3) reinitializing the moment estimates ($\boldsymbol{v}$, $\boldsymbol{m}$) at each round slows convergence. To address these challenges, we propose the first Federated AdamW algorithm, called FedAdamW, for training and fine-tuning various large models. FedAdamW combines a local correction mechanism with decoupled weight decay to align local updates with the global update and mitigate local overfitting. FedAdamW also efficiently aggregates the mean of the second-moment estimates, which reduces their variance, and uses this aggregate to reinitialize them. Theoretically, we prove that FedAdamW achieves a linear speedup convergence rate of $\mathcal{O}\big(\sqrt{L\Delta\sigma_l^2/(SKR\epsilon^2)} + L\Delta/R\big)$ without a heterogeneity assumption, where $S$ is the number of participating clients per round, $K$ is the number of local iterations, and $R$ is the total number of communication rounds. We also use a PAC-Bayesian generalization analysis to explain the effectiveness of decoupled weight decay in local training. Empirically, we validate FedAdamW on language and vision Transformer models. Compared with several baselines, FedAdamW significantly reduces the number of communication rounds and improves test accuracy. The code is available at https://github.com/junkangLiu0/FedAdamW.
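
The sketch below shows one local step that combines AdamW's decoupled weight decay with a drift-correction term. The AdamW part follows the standard update. The correction is written here as `alpha * global_dir`, where `global_dir` is an estimate of the previous round's global update direction in gradient units. This form is a hypothetical choice made for illustration. `alpha`, `global_dir`, and `LocalState` are assumed names and do not come from the paper's notation or code.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class LocalState:
    """Per-client AdamW moments (hypothetical container).

    If v is reinitialized from aggregated server statistics, keep t from the
    previous round (or drop bias correction). Otherwise an already-warm v is
    rescaled by 1 / (1 - beta2**t) as if it were freshly zero-initialized.
    """
    m: List[float]
    v: List[float]
    t: int = 0


def local_adamw_step(x: List[float], grad: List[float], state: LocalState,
                     global_dir: List[float], lr: float = 1e-3,
                     beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8,
                     weight_decay: float = 0.01, alpha: float = 0.5) -> List[float]:
    """One AdamW step with decoupled weight decay plus a hypothetical
    correction alpha * global_dir that nudges the step toward the previous
    global update direction (expressed in gradient units)."""
    state.t += 1
    new_x = []
    for j, (xj, gj) in enumerate(zip(x, grad)):
        # Standard Adam first- and second-moment updates with bias correction.
        state.m[j] = beta1 * state.m[j] + (1 - beta1) * gj
        state.v[j] = beta2 * state.v[j] + (1 - beta2) * gj * gj
        m_hat = state.m[j] / (1 - beta1 ** state.t)
        v_hat = state.v[j] / (1 - beta2 ** state.t)
        adam_dir = m_hat / (math.sqrt(v_hat) + eps)
        # Decoupled weight decay acts on the weights directly, not through grad.
        new_x.append(xj - lr * (adam_dir + alpha * global_dir[j] + weight_decay * xj))
    return new_x
```

Here `weight_decay` is applied directly to `x` instead of being added to `grad`, so the decay is not rescaled by `1 / sqrt(v_hat)`. This is what distinguishes AdamW from Adam with L2 regularization.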

Paper Structure

This paper contains 46 sections, 12 theorems, 65 equations, 8 figures, 11 tables, 5 algorithms.

Key Result

Theorem 1

Under the smoothness assumption and the two bounded-stochastic-gradient assumptions (I and II), if we take $g^0 = 0$, $\beta_1 = 0$, and $\lambda = 0$, then FedAdamW converges with a bound stated in terms of quantities including $G_0 := \frac{1}{N} \sum_{i=1}^N \left\|\nabla f_i\left(\boldsymbol{x}^0\right)\right\|^2$, $\Delta = f\left(\boldsymbol{x}^0\right) - f^{\star}$, $S$, the number of participating clients per round, and $K$, the number of local iterations.
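
The sketch below evaluates $\sqrt{L\Delta\sigma_l^2/(SKR\epsilon^2)} + L\Delta/R$, the order-level rate quoted in the abstract, to show what linear speedup means. All hidden constants are set to 1. This is an assumption, because the $\mathcal{O}(\cdot)$ notation hides them. `eps` stands for the $\epsilon$ in the rate. Any inputs passed in are purely illustrative and are not results from the paper.

```python
import math
from typing import Tuple


def fedadamw_rate_terms(L: float, delta: float, sigma_l: float,
                        S: int, K: int, R: int, eps: float) -> Tuple[float, float]:
    """Evaluate the two terms of the abstract's rate with hidden constants set to 1:
    sqrt(L * delta * sigma_l**2 / (S * K * R * eps**2)) and L * delta / R."""
    stochastic = math.sqrt(L * delta * sigma_l ** 2 / (S * K * R * eps ** 2))
    deterministic = L * delta / R
    return stochastic, deterministic


if __name__ == "__main__":
    # Illustrative inputs only: compare the terms for S = 10 and S = 40 clients.
    for S in (10, 40):
        print(S, fedadamw_rate_terms(L=1.0, delta=1.0, sigma_l=1.0, S=S, K=10, R=100, eps=1.0))
```

The first term depends on $S$, $K$, and $R$ only through the product $SKR$, so quadrupling $SK$ halves it. The number of rounds needed to drive that term below a fixed target therefore shrinks in proportion to $1/(SK)$. The second term, $L\Delta/R$, does not benefit from more clients or more local steps.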

Figures (8)

  • Figure 1: Performance of Local SGD and Local AdamW. We carefully tune the learning rate for training ViT-Base, GPT-2, and BERT [liu2019roberta]. Even so, Local SGD remains significantly worse than Local AdamW on all these Transformer models.
  • Figure 2: Training on CIFAR-100 using ViT-Tiny. (a) Data heterogeneity causes high variance in Local AdamW's second-moment estimates across clients. (b) Local AdamW suffers from more severe client drift than Local SGD under non-i.i.d. data.
  • Figure 3: (a–c) Block-wise Hessian structure of Transformer parameters under FL, visualized through the Hessian submatrices of the query, key, and value heads. The near-block-diagonal structure supports block-wise second-moment aggregation in FedAdamW.
  • Figure 4: Illustration of FedAdamW’s block-wise aggregation strategy. Clients estimate local second-moment statistics and send block-wise means to the server, reducing communication cost.
  • Figure 5: An illustration of the local update in FedAdamW, which uses global update guidance to correct client drift.
  • ...and 3 more figures
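
The sketch below puts the communication saving from Figure 4 into rough numbers. It counts the floats uploaded for the second moments of one attention layer's query, key, and value weight matrices. It assumes a ViT-Base-sized layer (hidden size 768, 12 heads), one second-moment mean per head per matrix, and no biases. This head-wise partition is an illustrative assumption and may differ from the partition FedAdamW actually uses. The function name is hypothetical.

```python
def qkv_second_moment_floats(hidden: int, heads: int, blockwise: bool) -> int:
    """Floats uploaded for the second moments of one attention layer's
    Q, K, and V weight matrices (biases ignored; hypothetical head-wise blocks)."""
    per_matrix = heads if blockwise else hidden * hidden
    return 3 * per_matrix


if __name__ == "__main__":
    full = qkv_second_moment_floats(768, 12, blockwise=False)  # 3 * 768 * 768
    block = qkv_second_moment_floats(768, 12, blockwise=True)  # 3 * 12
    print(full, block, full // block)
```

Under these assumptions, the full upload is 1,769,472 floats, compared with 36 block means. This is a 49,152-fold reduction for this layer's second-moment statistics.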

Theorems & Definitions (20)

  • Theorem 1: Convergence for non-convex functions
  • Theorem 2
  • Lemma 1
  • Lemma 2
  • Proof
  • Lemma 3
  • Proof
  • Lemma 4
  • Proof
  • Lemma 5
  • ...and 10 more