Order Optimal Bounds for One-Shot Federated Learning over non-Convex Loss Functions
Arsalan Sharifnassab, Saber Salehkaleybar, S. Jamaloddin Golestani
TL;DR
This work establishes fundamental limits for one-shot federated learning with non-convex losses by deriving a minimax lower bound on the excess population loss that depends on the number of machines, the number of samples per machine, and the per-machine communication budget $B$. It then introduces the Multi-Resolution Estimator for Non-Convex loss (MRE-NC), a distributed algorithm that constructs a multi-resolution approximation of the population loss over a grid and uses this surrogate to recover an approximate minimizer, achieving an upper bound that matches the lower bound up to polylogarithmic factors in the large-sample regime. A parallel analysis shows a constant lower bound when the communication budget is very small, highlighting a fundamental limitation when $B$ is fixed. The theoretical results are complemented by experiments on synthetic and real data (e.g., MNIST), which illustrate that MRE-NC can outperform naive baselines and approach centralized performance in small-scale settings, while scalability in high dimensions remains a challenge. Overall, the paper provides a rigorous information-theoretic characterization of the trade-offs between communication, sample size, and dimensionality in one-shot federated learning with non-convex losses, and offers a concrete, near-optimal algorithm within this regime.
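To make the one-shot protocol concrete, here is a minimal Python sketch of a much simpler single-resolution grid baseline. It is not MRE-NC, which builds its approximation at multiple resolutions. Everything in it is an assumption for illustration: the toy sample loss $f(x;z)=\sin(4x)+0.5x^2+zx$ with $z\sim\mathcal{N}(0,1)$ in dimension $d=1$, the a-priori value range $[-3,3]$, the uniform quantizer, and all function names. Each machine evaluates the average of its $n$ sample losses at $k$ grid points and quantizes each value to `bits` bits, so its signal uses $k\cdot$`bits` bits, which plays the role of $B$. The server averages the dequantized surrogates over the $m$ machines and returns the grid point with the smallest value.

```python
import math
import random

# Hypothetical toy instance (an assumption, not from the paper): each sample
# function is f(x; z) = sin(4x) + 0.5*x^2 + z*x with z ~ N(0, 1), so the
# population loss F(x) = sin(4x) + 0.5*x^2 is non-convex on [-1, 1] (d = 1).
def sample_loss(x, z):
    return math.sin(4 * x) + 0.5 * x * x + z * x

def population_loss(x):
    return math.sin(4 * x) + 0.5 * x * x

def make_grid(k):
    # k evenly spaced points covering [-1, 1]
    return [-1 + 2 * j / (k - 1) for j in range(k)]

def quantize(value, lo, hi, bits):
    # Clip to the assumed range [lo, hi] and map to an integer in [0, 2^bits - 1].
    levels = 2 ** bits - 1
    clipped = min(max(value, lo), hi)
    return round((clipped - lo) / (hi - lo) * levels)

def dequantize(code, lo, hi, bits):
    levels = 2 ** bits - 1
    return lo + code / levels * (hi - lo)

def machine_signal(zs, grid, lo, hi, bits):
    # One machine: average its n sample losses at every grid point and quantize.
    # The signal uses len(grid) * bits bits in total, which must not exceed B.
    n = len(zs)
    return [quantize(sum(sample_loss(x, z) for z in zs) / n, lo, hi, bits)
            for x in grid]

def server_estimate(signals, grid, lo, hi, bits):
    # Server: average the dequantized surrogates over machines, return the grid argmin.
    m = len(signals)
    avg = [sum(dequantize(s[j], lo, hi, bits) for s in signals) / m
           for j in range(len(grid))]
    return grid[min(range(len(grid)), key=lambda j: avg[j])]

def run(m=200, n=50, k=32, bits=8, seed=0):
    rng = random.Random(seed)
    grid = make_grid(k)
    lo, hi = -3.0, 3.0  # assumed a-priori bound on empirical loss values
    signals = [machine_signal([rng.gauss(0.0, 1.0) for _ in range(n)],
                              grid, lo, hi, bits)
               for _ in range(m)]
    return server_estimate(signals, grid, lo, hi, bits)

if __name__ == "__main__":
    x_hat = run()
    best_on_grid = min(population_loss(x) for x in make_grid(32))
    print("estimate:", x_hat, "excess loss vs. best grid point:",
          population_loss(x_hat) - best_on_grid)
```

Because the grid is fixed, this baseline can only locate the minimizer to within the grid spacing $2/(k-1)$, however large $m$ and $n$ are.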
Abstract
We consider the problem of federated learning in a one-shot setting in which there are $m$ machines, each observing $n$ sample functions from an unknown distribution on non-convex loss functions. Let $F:[-1,1]^d\to\mathbb{R}$ be the expected loss function with respect to this unknown distribution. The goal is to find an estimate of the minimizer of $F$. Based on its observations, each machine generates a signal of bounded length $B$ and sends it to a server. The server collects the signals from all machines and outputs an estimate of the minimizer of $F$. We show that the expected loss of any algorithm is lower bounded by $\max\big(1/(\sqrt{n}(mB)^{1/d}), 1/\sqrt{mn}\big)$, up to a logarithmic factor. We then prove that this lower bound is order-optimal in $m$ and $n$ by presenting a distributed learning algorithm, called Multi-Resolution Estimator for Non-Convex loss function (MRE-NC), whose expected loss matches the lower bound for large $mn$ up to polylogarithmic factors.
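The lower bound has two regimes, and a short calculation shows where they meet: $1/(\sqrt{n}(mB)^{1/d}) = 1/\sqrt{mn}$ holds exactly when $(mB)^{1/d}=\sqrt{m}$, that is, when $B = m^{d/2-1}$, with $n$ cancelling out. The Python sketch below evaluates the stated bound with constants and logarithmic factors dropped and computes this crossover budget. The function names are hypothetical helpers, not code from the paper.

```python
import math

def lower_bound_order(m, n, B, d):
    """Order of the stated lower bound max(1/(sqrt(n) (mB)^(1/d)), 1/sqrt(mn)),
    with constants and logarithmic factors dropped."""
    communication_term = 1.0 / (math.sqrt(n) * (m * B) ** (1.0 / d))
    statistical_term = 1.0 / math.sqrt(m * n)
    return max(communication_term, statistical_term)

def crossover_budget(m, d):
    """Budget B at which the two terms coincide.

    1/(sqrt(n) (mB)^(1/d)) = 1/sqrt(mn)  <=>  (mB)^(1/d) = sqrt(m)
    <=>  B = m^(d/2 - 1); n cancels out."""
    return m ** (d / 2 - 1)

if __name__ == "__main__":
    m, n, d = 100, 100, 4
    print("crossover B:", crossover_budget(m, d))
    for B in (1, 100, 10**6):
        print(B, lower_bound_order(m, n, B, d))
```

For example, with $m=n=100$ and $d=4$ the crossover is at $B=100$. Below it, the communication term dominates. Above it, the bound is $1/\sqrt{mn}=0.01$, which does not depend on $B$.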
