Cohort Squeeze: Beyond a Single Communication Round per Cohort in Cross-Device Federated Learning
Kai Yi, Timur Kharisov, Igor Sokolov, Peter Richtárik
TL;DR
This paper tackles the high communication cost of cross-device federated learning by challenging the traditional single-round-per-cohort paradigm. It introduces SPPM-AS, a stochastic proximal point method with arbitrary cohort sampling that allows multiple local proximal updates within each global iteration, reducing total communication cost while preserving convergence guarantees. The authors develop a unified theory around sampling distributions, proximal updates, and iteration complexity, and instantiate NICE, BS, and SS sampling schemes, with stratified sampling (SS) often yielding the best variance properties. Extensive experiments on convex logistic regression and non-convex neural networks (including FEMNIST) show up to 74% communication-cost reduction in standard FL and even higher savings in hierarchical FL, validating both the approach and the practical tuning guidelines. Overall, the work provides a principled path to more communication-efficient cross-device FL by leveraging flexible sampling and multi-round cohort interactions.
Abstract
Virtually all federated learning (FL) methods, including FedAvg, operate in the following manner: i) an orchestrating server sends the current model parameters to a cohort of clients selected via a certain rule, ii) these clients then independently perform a local training procedure (e.g., via SGD or Adam) using their own training data, and iii) the resulting models are shipped to the server for aggregation. This process is repeated until a model of suitable quality is found. A notable feature of these methods is that each cohort is involved in only a single communication round with the server. In this work, we challenge this algorithmic design primitive and investigate whether it is possible to "squeeze more juice" out of each cohort than what is possible in a single communication round. Surprisingly, we find that this is indeed the case, and our approach leads to up to a 74% reduction in the total communication cost needed to train an FL model in the cross-device setting. Our method is based on a novel variant of the stochastic proximal point method (SPPM-AS), which supports a large collection of client sampling procedures, some of which lead to further gains when compared to classical client selection approaches.
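The contrast between one round per cohort and several rounds per cohort can be illustrated with a minimal sketch. It is not the authors' SPPM-AS implementation: the synthetic least-squares clients, the uniform cohort sampling, the gradient-descent inner solver, and all names and parameter values (`make_clients`, `train`, `gamma`, `lr`, `cohort_size`) are assumptions made for illustration. Each cohort approximately solves a proximal subproblem anchored at the current model, and every inner step stands in for one server-cohort communication round.

```python
import numpy as np


def make_clients(n_clients=20, dim=5, n_samples=30, seed=0):
    """Build hypothetical clients, each holding a small least-squares dataset."""
    rng = np.random.default_rng(seed)
    x_true = rng.normal(size=dim)
    clients = []
    for _ in range(n_clients):
        A = rng.normal(size=(n_samples, dim))
        b = A @ x_true + 0.01 * rng.normal(size=n_samples)
        clients.append((A, b))
    return clients, x_true


def cohort_grad(x, cohort):
    """Average gradient of the per-client loss 0.5/m * ||A x - b||^2 over a cohort."""
    grads = [A.T @ (A @ x - b) / len(b) for A, b in cohort]
    return np.mean(grads, axis=0)


def train(clients, n_cohorts, rounds_per_cohort,
          gamma=1.0, lr=0.1, cohort_size=5, seed=1):
    """Run n_cohorts global iterations with rounds_per_cohort rounds each.

    Assumed scheme: each cohort approximately minimizes
        f_cohort(y) + ||y - x||^2 / (2 * gamma)
    by gradient descent. With rounds_per_cohort=1 this reduces to a single
    gradient step per cohort, i.e. one communication round per cohort.
    """
    rng = np.random.default_rng(seed)
    x = np.zeros(clients[0][0].shape[1])
    for _ in range(n_cohorts):
        # Uniform sampling without replacement (an assumed sampling rule).
        idx = rng.choice(len(clients), size=cohort_size, replace=False)
        cohort = [clients[i] for i in idx]
        anchor = x.copy()
        y = x.copy()
        for _ in range(rounds_per_cohort):
            # One communication round: the server aggregates cohort gradients
            # and adds the gradient of the proximal term.
            g = cohort_grad(y, cohort) + (y - anchor) / gamma
            y = y - lr * g
        x = y
    return x


if __name__ == "__main__":
    clients, x_true = make_clients()
    for k in (1, 5):
        x = train(clients, n_cohorts=20, rounds_per_cohort=k)
        print(k, np.linalg.norm(x - x_true))
```

The sketch only shows the mechanics of reusing a cohort for several rounds. It compares runs with the same number of cohorts rather than the same communication budget, so it says nothing about the savings reported in the paper.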
