Understanding Generalization of Federated Learning: the Trade-off between Model Stability and Optimization
Dun Zeng, Zheshun Wu, Shiyu Liu, Yu Pan, Xiaoying Tang, Zenglin Xu
TL;DR
This work addresses the generalization gap in federated learning under data heterogeneity by introducing Libra, a framework that jointly analyzes algorithm-dependent excess risk through model stability and gradient-norm dynamics. By deriving bounds on stability, gradient norms, and the minimum excess risk, Libra clarifies how hyperparameters such as global stepsize $\eta_g$, local steps $K$, and momentum $\beta$ trade off convergence speed against stability to influence generalization. Theoretical results reveal an explicit balance point leading to minimum excess risk and identify regimes corresponding to under-fitting, benign-fitting, and over-fitting, with practical guidance for hyperparameter tuning. Empirical evaluations on CIFAR-10/100 and NLP tasks corroborate the predicted trade-offs, demonstrating how learning-rate decay and momentum affect generalization, and offering a principled basis for designing FL algorithms with stronger generalization properties.
Abstract
Federated Learning (FL) is a distributed learning approach that trains machine learning models across multiple devices while keeping their local data private. However, FL often faces challenges due to data heterogeneity, which leads to inconsistent local optima among clients. These inconsistencies can degrade both convergence behavior and generalization performance. Existing studies often describe this issue through \textit{convergence analysis} on gradient norms, focusing on how well a model fits training data, or through \textit{algorithmic stability}, which examines the generalization gap. However, neither approach alone precisely captures the generalization performance of FL algorithms, especially for non-convex neural network training. In response, this paper introduces a generalization dynamics analysis framework, namely \textit{Libra}, for algorithm-dependent excess risk minimization, highlighting the trade-offs between model stability and gradient norms. We apply Libra to a standard federated optimization framework and its variants using server momentum. Through this framework, we show that larger local steps or momentum accelerate convergence of gradient norms while worsening model stability, so that the minimum excess risk is attained by balancing the two effects. Experimental results on standard FL settings support our theoretical insights. These insights can guide hyperparameter tuning and future algorithm design to achieve stronger generalization.
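To make the roles of the hyperparameters concrete, the following is a minimal sketch of federated averaging with server momentum on a toy problem. It is an illustration under stated assumptions, not the paper's implementation: each client objective is assumed to be the quadratic $f_i(w) = \frac{1}{2}\|w - c_i\|^2$, the server momentum update $m \leftarrow \beta m + \bar{\Delta}$ is one common form and may differ from the variant analyzed in the paper, and the names `local_sgd`, `fedavg_server_momentum`, `centers`, and `local_lr` are hypothetical.

```python
def local_sgd(w, center, local_lr, K):
    """Run K local gradient steps on the toy objective
    f_i(w) = 0.5 * ||w - center||^2 (an assumed client loss)."""
    w = list(w)
    for _ in range(K):
        grad = [wj - cj for wj, cj in zip(w, center)]
        w = [wj - local_lr * gj for wj, gj in zip(w, grad)]
    return w


def fedavg_server_momentum(centers, rounds, K, local_lr, eta_g, beta):
    """Federated averaging with an assumed heavy-ball style server momentum.

    centers  : one minimizer c_i per client (heterogeneous when they differ)
    K        : number of local steps per round
    eta_g    : global (server) stepsize
    beta     : server momentum coefficient (beta = 0 disables momentum)
    """
    dim = len(centers[0])
    w = [0.0] * dim  # global model
    m = [0.0] * dim  # server momentum buffer
    for _ in range(rounds):
        deltas = []
        for c in centers:
            # Every client starts from the current global model.
            w_local = local_sgd(w, c, local_lr, K)
            deltas.append([wl - wg for wl, wg in zip(w_local, w)])
        # Average the client updates coordinate-wise.
        avg_delta = [sum(d[j] for d in deltas) / len(deltas) for j in range(dim)]
        # Assumed momentum form: accumulate the averaged update.
        m = [beta * mj + dj for mj, dj in zip(m, avg_delta)]
        w = [wj + eta_g * mj for wj, mj in zip(w, m)]
    return w


if __name__ == "__main__":
    clients = [[1.0], [3.0]]  # two clients with different local optima
    for beta in (0.0, 0.5):
        w = fedavg_server_momentum(clients, rounds=100, K=5,
                                   local_lr=0.1, eta_g=1.0, beta=beta)
        print(f"beta={beta}: global model = {w}")
```

In this toy setting all clients share the same curvature, so the global model settles at the average of the client minimizers. The sketch only shows where $K$, $\eta_g$, and $\beta$ enter the update; it does not measure stability or excess risk.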
