Order Optimal Bounds for One-Shot Federated Learning over non-Convex Loss Functions

Arsalan Sharifnassab, Saber Salehkaleybar, S. Jamaloddin Golestani

TL;DR

This work establishes fundamental limits for one-shot federated learning with non-convex losses by deriving a minimax lower bound on the excess population loss that depends on the number of machines, the number of samples per machine, and the per-machine communication budget. It then introduces the Multi-Resolution Estimator for Non-Convex loss (MRE-NC), a distributed algorithm that constructs a multi-resolution approximation of the population loss over a grid and uses this surrogate to recover an approximate minimizer, achieving an upper bound that matches the lower bound up to polylogarithmic factors in the large-sample regime. A parallel analysis shows a constant lower bound under a tiny communication budget, highlighting a fundamental limitation when $B$ is fixed. The theoretical results are complemented by experiments on synthetic and real data (e.g., MNIST), which illustrate that MRE-NC can outperform naive baselines and approach centralized performance in small-scale settings, while acknowledging scalability challenges in high dimensions. Overall, the paper provides a rigorous information-theoretic characterization of the trade-offs between communication, sample size, and dimensionality in one-shot federated learning for non-convex losses, and offers a concrete, near-optimal algorithm within this regime.

Abstract

We consider the problem of federated learning in a one-shot setting in which there are $m$ machines, each observing $n$ sample functions from an unknown distribution on non-convex loss functions. Let $F:[-1,1]^d\to\mathbb{R}$ be the expected loss function with respect to this unknown distribution. The goal is to find an estimate of the minimizer of $F$. Based on its observations, each machine generates a signal of bounded length $B$ and sends it to a server. The server collects signals of all machines and outputs an estimate of the minimizer of $F$. We show that the expected loss of any algorithm is lower bounded by $\max\big(1/(\sqrt{n}(mB)^{1/d}), 1/\sqrt{mn}\big)$, up to a logarithmic factor. We then prove that this lower bound is order optimal in $m$ and $n$ by presenting a distributed learning algorithm, called Multi-Resolution Estimator for Non-Convex loss function (MRE-NC), whose expected loss matches the lower bound for large $mn$ up to polylogarithmic factors.
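
To get a feel for the lower bound stated in the abstract, the following minimal sketch evaluates $\max\big(1/(\sqrt{n}(mB)^{1/d}), 1/\sqrt{mn}\big)$ numerically. It ignores the logarithmic factor and any constants, so it shows only the order of the bound; the function name `lower_bound_rate` is an illustrative choice, not something taken from the paper.

```python
import math


def lower_bound_rate(m: int, n: int, B: int, d: int) -> float:
    """Order of the lower bound from the abstract (log factors and constants dropped).

    m: number of machines, n: samples per machine,
    B: signal length per machine, d: dimension of the parameter domain [-1, 1]^d.
    """
    # Term driven by the communication budget: 1 / (sqrt(n) * (m*B)^(1/d)).
    comm_term = 1.0 / (math.sqrt(n) * (m * B) ** (1.0 / d))
    # Term that remains even if all m*n samples were pooled: 1 / sqrt(m*n).
    central_term = 1.0 / math.sqrt(m * n)
    return max(comm_term, central_term)


if __name__ == "__main__":
    for d in (1, 2, 6):
        print(d, lower_bound_rate(m=100, n=10, B=8, d=d))
```

Comparing the two terms directly shows that the communication term exceeds $1/\sqrt{mn}$ exactly when $B < m^{d/2-1}$, so under this formula the communication budget can only be the bottleneck for $d > 2$.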

Paper Structure

This paper contains 32 sections, 18 theorems, 161 equations, 5 figures, 1 algorithm.

Key Result

Theorem 1

For any $\mathcal{C}\ge1$, any $m\ge M_\mathcal{C}$, and any estimator with output denoted by $\hat{\theta}$, there exist a distribution $P$ and a corresponding function $F$, defined in \ref{eq:def of the loss F}, for which with probability at least $1/2$,

Figures (5)

  • Figure 1: The considered distributed system consists of $m$ identical machines, each observing $n$ independent sample functions from an unknown distribution $P$. Each machine $i$ sends a signal $Y_i$ of length $B$ bits to a server. The server collects all the signals and returns an estimate $\hat{\theta}$ for the optimization problem in \ref{eq:opt problem}.
  • Figure 2: An illustration of a $p$-point in $[-1,1]^d$ for $d=2$. The point $p$ belongs to $G^2$ and $p'$ is the parent of $p$.
  • Figure 3: Illustrations of functions $h$ and $f_\sigma$ for $d=2$. (a) shows the surface of $h(\cdot)$ defined in \ref{eqa:def h} and (b) is an example of $f_\sigma(\cdot)$ defined in \ref{eqa:fsigma}.
  • Figure 4: Comparison of the performance of MRE-NC with two naive approaches. The number of parameters ($d$) and the number of samples per machine ($n$) are $6$ and $10$, respectively.
  • Figure 5: The performance of MRE-NC (loss and classification error) in classifying digits from the MNIST dataset, plotted against the number of machines. The left and right y-axes correspond to the true loss function and the classification error of the obtained model, respectively. The number of parameters per weak learner and the number of samples per machine ($n$) are $3$ and $10$, respectively.
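
The one-shot setup of Figure 1 can be simulated in a few lines. The sketch below is a toy baseline and not MRE-NC: each machine minimizes its local empirical loss over a uniform grid of $2^B$ points and sends the winning index ($B$ bits), and the server returns the most frequently reported grid point. The sample loss, the Gaussian sample distribution, and all function names are hypothetical choices made for illustration, and the domain is restricted to $d=1$.

```python
import random
from collections import Counter


def sample_loss(theta: float, z: float) -> float:
    # Hypothetical non-convex sample loss on [-1, 1]: a double well tilted by z.
    return (theta * theta - 0.25) ** 2 + z * theta


def machine_signal(samples: list, B: int) -> int:
    # Evaluate the local empirical loss on a uniform grid of 2**B points
    # and return the index of the best grid point, which fits in B bits.
    k = 2 ** B
    grid = [-1.0 + 2.0 * j / (k - 1) for j in range(k)]

    def empirical(theta: float) -> float:
        return sum(sample_loss(theta, z) for z in samples) / len(samples)

    return min(range(k), key=lambda j: empirical(grid[j]))


def server_estimate(signals: list, B: int) -> float:
    # Plurality vote over the reported indices, decoded back to a grid point.
    k = 2 ** B
    winner, _ = Counter(signals).most_common(1)[0]
    return -1.0 + 2.0 * winner / (k - 1)


def one_shot(m: int, n: int, B: int, seed: int = 0) -> float:
    # m machines, each drawing n i.i.d. samples (assumed Gaussian here)
    # and sending a single B-bit signal to the server.
    rng = random.Random(seed)
    signals = []
    for _ in range(m):
        samples = [rng.gauss(0.1, 0.5) for _ in range(n)]
        signals.append(machine_signal(samples, B))
    return server_estimate(signals, B)


if __name__ == "__main__":
    print(one_shot(m=50, n=10, B=4))
```

A plurality vote is used instead of averaging because, with a non-convex loss, local minimizers can fall into different wells and their average need not be close to any minimizer.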

Theorems & Definitions (21)

  • Theorem 1
  • Corollary 1
  • Theorem 2
  • Corollary 2
  • Remark 1
  • Proposition 1
  • Proposition 2
  • Proposition 3
  • Lemma 1
  • Lemma 2
  • ...and 11 more