Sparse-ProxSkip: Accelerated Sparse-to-Sparse Training in Federated Learning

Georg Meinhardt, Kai Yi, Laurent Condat, Peter Richtárik

TL;DR

This work targets Federated Learning under tight client resources and high communication costs. It shows that naive combinations of acceleration with sparse-to-sparse training fail in FL due to drift control issues and proposes Sparse-ProxSkip, which fuses local STE pruning with accelerated proximal updates while enforcing a zero-sum control variate condition. Building on RandProx theory, it extends to nonconvex sparsity via hard TopK pruning and introduces two variants, Sparse-ProxSkip and Sparse-ProxSkip-Local, with local steps set to $k=\frac{1}{p}$. Across Blog Feedback, FEMNIST, and CIFAR-10, the method reaches higher final accuracy with less uplink communication than the baselines, illustrating a practical path to scalable sparse FL training.

Abstract

In Federated Learning (FL), both client resource constraints and communication costs pose major problems for training large models. In the centralized setting, sparse training addresses resource constraints, while in the distributed setting, local training addresses communication costs. Recent work has shown that local training provably improves communication complexity through acceleration. In this work we show that in FL, naive integration of sparse training and acceleration fails, and we provide theoretical and empirical explanations of this phenomenon. We introduce Sparse-ProxSkip, addressing the issue and implementing the efficient technique of Straight-Through Estimator pruning into sparse training. We demonstrate the performance of Sparse-ProxSkip in extensive experiments.
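
To make the two pruning ingredients named above concrete, here is a minimal NumPy sketch of hard TopK pruning and a single Straight-Through Estimator (STE) step. It is an illustration under stated assumptions, not the paper's implementation: the function names `top_k` and `ste_step`, the 1-D weight vector, and the plain gradient step are all hypothetical choices. The STE convention assumed here is the common one, where the gradient is evaluated at the pruned weights and then applied to the dense weights.

```python
import numpy as np


def top_k(w, k):
    """Keep the k largest-magnitude entries of the 1-D array w, zero the rest."""
    if k <= 0:
        return np.zeros_like(w)
    if k >= w.size:
        return w.copy()
    # Indices of the k entries with the largest absolute value (unordered).
    idx = np.argpartition(np.abs(w), -k)[-k:]
    out = np.zeros_like(w)
    out[idx] = w[idx]
    return out


def ste_step(w_dense, grad_fn, k, lr):
    """One STE step (assumed convention): gradient at pruned weights,
    update applied to the dense weights."""
    w_sparse = top_k(w_dense, k)   # forward pass sees only the sparse model
    g = grad_fn(w_sparse)          # gradient evaluated at the sparse point
    return w_dense - lr * g        # passed "straight through" to dense weights


# Toy usage with f(w) = 0.5 * ||w||^2, whose gradient is w itself.
w0 = np.array([3.0, -1.0, 0.5, -4.0])
w1 = ste_step(w0, grad_fn=lambda w: w, k=2, lr=0.1)
```

In this toy run only the two retained coordinates receive a nonzero gradient, while the dense copy keeps the small entries so that they can re-enter the TopK support in later steps.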

Paper Structure

This paper contains 19 sections, 13 equations, 8 figures, 5 tables.

Figures (8)

  • Figure 1: On the left, test score for regression on the Blog Feedback dataset [buza2013feedback]. Our method performs best in both final score and communication efficiency. On the right, test accuracy for ResNet18 [he2016deep] on CIFAR-10 [cifar10]. Our method Sparse-ProxSkip prevents the catastrophic failure that occurs when acceleration is combined with pruning at the server. The shaded area in both plots represents the standard error.
  • Figure 2: Test score ($R^2$) on the left and train loss on the right for regression on the Blog Feedback dataset [buza2013feedback]. Baseline methods are dashed while our methods are solid. We observe that both RandProx-$\ell_1$ and our proposed methods converge to a better solution in a substantially more communication-efficient way. The shaded area in the figures represents the standard error. Error bars for all experiments are included but are sometimes not visible, due to deterministic initialization at $w_{i,0} = \mathbf{0}$.
  • Figure 3: Results for logistic regression on FEMNIST at $99 \%$ sparsity. Sparse-ProxSkip and Sparse-ProxSkip-Local outperform all baselines in both communication cost and final accuracy. The shaded area in the figures represents the standard error.
  • Figure 4: Results for ResNet18 [he2016deep] on CIFAR-10 [cifar10] at $90 \%$ sparsity. Sparse-ProxSkip still outperforms the baselines, though to a lesser degree. The main observation is that Accelerated-Server-Pruning fails completely in accuracy and loss because of $|\sum_i h_i| \gg 0$, and that the proposed fixes of Sparse-ProxSkip address this problem. The shaded area in the figures represents the standard error.
  • Figure 5: Distribution of the client sizes in the federated version of the Blog Feedback dataset [buza2013feedback].
  • ...and 3 more figures
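
The Figure 4 caption attributes the failure of server-side pruning to $|\sum_i h_i| \gg 0$. The sketch below illustrates why this can happen, assuming the standard ProxSkip-style control variate update $h_i \leftarrow h_i + \frac{p}{\gamma}(x - \hat{x}_i)$ at a communication round, where $x$ is the model the server sends back. The helper names `top_k` and `communication_round`, the step size, and the random data are hypothetical and only serve the illustration; this is not the paper's code.

```python
import numpy as np


def top_k(w, k):
    """Keep the k largest-magnitude entries of the 1-D array w, zero the rest."""
    idx = np.argpartition(np.abs(w), -k)[-k:]
    out = np.zeros_like(w)
    out[idx] = w[idx]
    return out


def communication_round(x_hat, h, p, gamma, server_op=None):
    """One communication round under the assumed ProxSkip-style update.

    x_hat: (n, d) array of local iterates, one row per client.
    h:     (n, d) array of control variates, one row per client.
    server_op: optional map applied to the average at the server (e.g. pruning).
    """
    x_bar = x_hat.mean(axis=0)
    if server_op is not None:
        x_bar = server_op(x_bar)
    # Each client moves its control variate toward the model it received.
    h_new = h + (p / gamma) * (x_bar - x_hat)
    return x_bar, h_new


rng = np.random.default_rng(0)
n, d, p, gamma = 4, 6, 0.25, 0.1
x_hat = rng.normal(size=(n, d))
h0 = np.zeros((n, d))  # control variates start with zero sum

# Plain averaging: the increments cancel, so the sum of h_i stays zero.
_, h_plain = communication_round(x_hat, h0, p, gamma)

# Averaging followed by TopK at the server: the increments no longer cancel.
_, h_pruned = communication_round(
    x_hat, h0, p, gamma, server_op=lambda v: top_k(v, 2)
)

drift_plain = np.linalg.norm(h_plain.sum(axis=0))
drift_pruned = np.linalg.norm(h_pruned.sum(axis=0))
```

With plain averaging, $\sum_i (\bar{x} - \hat{x}_i) = 0$, so the zero-sum property is preserved. When the server returns $\mathrm{TopK}(\bar{x})$ instead, the sum changes by $\frac{p}{\gamma}\, n\, (\mathrm{TopK}(\bar{x}) - \bar{x})$ in each such round, which is generally nonzero.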