SPAM: Stochastic Proximal Point Method with Momentum Variance Reduction for Non-convex Cross-Device Federated Learning

Avetik Karagulyan, Egor Shulgin, Abdurakhmon Sadiev, Peter Richtárik

TL;DR

This paper addresses cross-device federated learning with non-convex losses, where the number of clients can reach into the billions. It introduces SPAM, a framework that couples Momentum Variance Reduction on the server with a Stochastic Proximal Point method on the clients, and extends it to partial participation via SPAM-PP. The analysis proves convergence under Hessian similarity without requiring smoothness, and establishes an optimal $O(K^{-1/3})$ convergence rate toward an $\varepsilon$-stationary point, with improved dependence on the Hessian similarity constant $\delta$ and the gradient variance $\sigma$. Experiments on a distributed ridge regression task corroborate the theory and illustrate robustness to inexact proximal computations. The resulting approach is flexible, state-free, and agnostic to the choice of local solver, which matters for communication efficiency in large-scale cross-device FL.
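
To make the server/client split concrete, here is a minimal sketch on a toy distributed ridge regression problem. It is not the paper's algorithm listing: the gradient-corrected proximal subproblem, the STORM-style momentum update, the independent sampling of the two clients per round, the initialisation of `g` from a single client, and all names (`make_clients`, `client_prox_step`, `mvr_update`, `run_spam_sketch`) and constants are assumptions chosen for illustration. Because each local loss is quadratic, the proximal subproblem has a closed-form solution.

```python
import numpy as np


def make_clients(num_clients=10, n=50, d=5, seed=0):
    """Synthetic clients with similar data (hence similar Hessians)."""
    rng = np.random.default_rng(seed)
    A_base = rng.normal(size=(n, d))
    x_true = np.full(d, 3.0)
    clients = []
    for _ in range(num_clients):
        A = A_base + 0.05 * rng.normal(size=(n, d))  # small per-client shift
        b = A @ x_true + 0.1 * rng.normal(size=n)
        clients.append((A, b))
    return clients


def grad(client, x, lam=0.1):
    """Gradient of f_m(x) = ||A x - b||^2 / (2 n) + lam * ||x||^2 / 2."""
    A, b = client
    return A.T @ (A @ x - b) / len(b) + lam * x


def full_grad(clients, x, lam=0.1):
    """Gradient of the average objective (used only for monitoring)."""
    return np.mean([grad(c, x, lam) for c in clients], axis=0)


def client_prox_step(client, x, g, gamma, lam=0.1):
    """Assumed client step: argmin_y f_m(y) + <g - grad f_m(x), y>
    + ||y - x||^2 / (2 * gamma). For a quadratic f_m with Hessian H this
    equals x - (H + I / gamma)^{-1} g."""
    A, b = client
    d = len(x)
    H = A.T @ A / len(b) + lam * np.eye(d)
    return x - np.linalg.solve(H + np.eye(d) / gamma, g)


def mvr_update(client, x_new, x_old, g, p, lam=0.1):
    """Assumed server step: momentum variance reduction estimator
    g_new = grad f_xi(x_new) + (1 - p) * (g - grad f_xi(x_old))."""
    return grad(client, x_new, lam) + (1.0 - p) * (g - grad(client, x_old, lam))


def run_spam_sketch(clients, gamma=1.0, p=0.5, iters=100, seed=1, lam=0.1):
    rng = np.random.default_rng(seed)
    x = np.zeros(clients[0][0].shape[1])
    # Assumption: initialise the estimator with one sampled client's gradient.
    g = grad(clients[rng.integers(len(clients))], x, lam)
    for _ in range(iters):
        m = rng.integers(len(clients))    # client that solves the prox
        x_new = client_prox_step(clients[m], x, g, gamma, lam)
        xi = rng.integers(len(clients))   # client that refreshes the estimator
        g = mvr_update(clients[xi], x_new, x, g, p, lam)
        x = x_new
    return x


clients = make_clients()
x_final = run_spam_sketch(clients)
print("initial gradient norm:", np.linalg.norm(full_grad(clients, np.zeros(5))))
print("final gradient norm:  ", np.linalg.norm(full_grad(clients, x_final)))
```

In this sketch the server keeps only the pair `(x, g)` and no per-client state, which mirrors the state-free property described above. With `p=1.0` the estimator reduces to a plain stochastic gradient, so `p` controls how much of the previous estimate is reused.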

Abstract

Cross-device training is a crucial subfield of federated learning, where the number of clients can reach into the billions. Standard approaches and local methods are prone to issues such as client drift and insensitivity to data similarities. We propose a novel algorithm (SPAM) for cross-device federated learning with non-convex losses, which solves both issues. We provide sharp analysis under second-order (Hessian) similarity, a condition satisfied by a variety of machine learning problems in practice. Additionally, we extend our results to the partial participation setting, where a cohort of selected clients communicates with the server at each communication round. Our method is the first of its kind that does not require smoothness of the objective and provably benefits from clients having similar data.

Paper Structure (38 sections, 16 theorems, 98 equations, 1 figure, 1 table, 3 algorithms)

Key Result

Proposition 3.1

Let $x_k$ be the iterates of SPAM for an objective function $f$ that satisfies the gradient-variance assumption (as:sigma) and the Hessian-similarity assumption (as:similarity). If $\gamma_k^2 \leq \min\left\{\frac{1}{16\delta^2},\frac{p_k}{96\delta^2(1-p_k)}\right\}$, then for every $k \geq 1$ where $V_k$ is defined in eq:lyapunov.
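
The stepsize condition above is easy to evaluate numerically. The helper below is a small sketch that returns the largest $\gamma_k$ allowed by the stated bound; the function name `max_stepsize` and the assumption that $p_k \in (0, 1]$ are illustrative and not taken from the paper. Setting the two terms equal shows that the second term is the binding one exactly when $p_k < 6/7$.

```python
import math


def max_stepsize(delta, p):
    """Largest gamma with gamma^2 <= min{1/(16 delta^2), p/(96 delta^2 (1-p))}.

    delta: Hessian similarity constant (must be positive).
    p:     momentum parameter, assumed here to lie in (0, 1].
    """
    if delta <= 0:
        raise ValueError("delta must be positive")
    if not 0 < p <= 1:
        raise ValueError("p is assumed to lie in (0, 1]")
    bound_sq = 1.0 / (16.0 * delta ** 2)
    if p < 1.0:
        # For p = 1 the second term is unbounded, so only the first applies.
        bound_sq = min(bound_sq, p / (96.0 * delta ** 2 * (1.0 - p)))
    return math.sqrt(bound_sq)


# Example: the bound scales as 1/delta, so more similar clients
# (smaller delta) permit larger proximal stepsizes.
for delta in (0.1, 1.0, 10.0):
    print(delta, max_stepsize(delta, p=0.5))
```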

Figures (1)

  • Figure 1: Convergence of SPAM-inexact on a ridge regression problem with different $p$ and $\gamma$.
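
As a rough illustration of what an inexact proximal computation can look like on ridge regression, the sketch below approximates the proximal operator of a single client's loss with a fixed number of gradient descent steps and compares it with the closed-form solution. This is an assumption-laden toy: the paper's formal notion of inexactness (Definition 4.1) is not reproduced here, gradient descent merely stands in for an arbitrary local solver, and the names `exact_prox` and `inexact_prox` are hypothetical.

```python
import numpy as np


def ridge_grad(A, b, x, lam):
    """Gradient of f(x) = ||A x - b||^2 / (2 n) + lam * ||x||^2 / 2."""
    return A.T @ (A @ x - b) / len(b) + lam * x


def exact_prox(A, b, v, gamma, lam):
    """Closed form of argmin_y f(y) + ||y - v||^2 / (2 * gamma)."""
    d = A.shape[1]
    H = A.T @ A / len(b) + lam * np.eye(d)
    return np.linalg.solve(H + np.eye(d) / gamma, A.T @ b / len(b) + v / gamma)


def inexact_prox(A, b, v, gamma, lam, steps):
    """Approximate the same subproblem with `steps` gradient descent steps."""
    d = A.shape[1]
    H = A.T @ A / len(b) + lam * np.eye(d)
    # Smoothness constant of the subproblem: lambda_max(H) + 1 / gamma.
    L = np.linalg.eigvalsh(H).max() + 1.0 / gamma
    y = v.copy()
    for _ in range(steps):
        y = y - (ridge_grad(A, b, y, lam) + (y - v) / gamma) / L
    return y


rng = np.random.default_rng(0)
A = rng.normal(size=(40, 5))
b = rng.normal(size=40)
v = np.zeros(5)
target = exact_prox(A, b, v, gamma=1.0, lam=0.1)
for steps in (1, 5, 50):
    approx = inexact_prox(A, b, v, gamma=1.0, lam=0.1, steps=steps)
    print(steps, np.linalg.norm(approx - target))
```

The subproblem is strongly convex thanks to the $\frac{1}{2\gamma}\|y - v\|^2$ term, so the approximation error shrinks geometrically with the number of local steps; the local step budget is therefore a direct knob on how inexact the proximal step is.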

Theorems & Definitions (19)

  • Proposition 3.1
  • Theorem 3.2: SPAM with constant parameters
  • Corollary 3.3
  • Corollary 3.4
  • Remark 3.1
  • Theorem 3.5: SPAM
  • Remark 3.2
  • Corollary 3.6: Optimal stepsize schedule
  • Definition 4.1: a-prox
  • Theorem 4.1: SPAM-inexact
  • ...and 9 more