When Foresight Pruning Meets Zeroth-Order Optimization: Efficient Federated Learning for Low-Memory Devices

Pengyu Zhang, Yingjie Liu, Yingbo Zhou, Xiao Du, Xian Wei, Ting Wang, Mingsong Chen

TL;DR

This paper tackles memory limitations in federated learning for AIoT devices by introducing a memory-efficient foresight pruning method that is grounded in Neural Tangent Kernel (NTK) theory and designed to work with backpropagation-free (BP-Free) training. By approximating the federated NTK with local, sparse NTK matrices and exploiting the method's data-free property to reduce the approximation error caused by data heterogeneity, the method reduces memory usage by up to 9x and lowers FLOPs while maintaining competitive accuracy. Key contributions include estimation-error bounds for Stein's identity in this setting, an NTK-based pruning objective with saliency-driven masking, and the integration of BP-Free training with a server-side perturbation scheme and a Random Seed Trick that minimizes communication. The approach is validated on CIFAR-10/100 under realistic AIoT-level settings, including real test-bed experiments, and shows gains in memory, computation, and communication efficiency, particularly under highly non-IID data distributions.
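The TL;DR mentions a Random Seed Trick for cutting communication. A common way to realize this idea, sketched below purely as an illustration under our own assumptions, is to send only a random seed and a scalar loss difference instead of a full perturbation vector, and to let the receiver regenerate the perturbation from the seed. The Gaussian perturbations $\boldsymbol{\delta}_k \sim \mathcal{N}(\mathbf{0}, \sigma^2\mathbf{I})$, the $1/(K\sigma^2)$ scaling, the toy quadratic loss, the `(seed, loss difference)` message format, and the helper names `perturbation`, `client_step`, `server_update`, and `demo` are all hypothetical, not the paper's actual protocol.

```python
import random
from typing import Callable, List, Tuple

Vector = List[float]
Message = Tuple[int, float]  # hypothetical message format: (seed, loss difference)


def perturbation(seed: int, n: int, sigma: float) -> Vector:
    """Regenerate the same Gaussian perturbation N(0, sigma^2 I) from a seed."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, sigma) for _ in range(n)]


def client_step(w: Vector, loss: Callable[[Vector], float],
                seed: int, sigma: float) -> Message:
    """Client: run one forward-difference probe and upload only (seed, scalar)."""
    delta = perturbation(seed, len(w), sigma)
    diff = loss([wi + di for wi, di in zip(w, delta)]) - loss(w)
    return seed, diff


def server_update(w: Vector, messages: List[Message],
                  sigma: float, lr: float) -> Vector:
    """Server: rebuild each perturbation from its seed, then take one
    zeroth-order step with g_hat = (1 / (K sigma^2)) * sum_k delta_k * diff_k."""
    n, K = len(w), len(messages)
    grad = [0.0] * n
    for seed, diff in messages:
        delta = perturbation(seed, n, sigma)
        for i in range(n):
            grad[i] += delta[i] * diff / (K * sigma ** 2)
    return [wi - lr * gi for wi, gi in zip(w, grad)]


def quadratic(v: Vector) -> float:
    """Toy loss 0.5 * ||v||^2 (assumption: stands in for a client's local loss)."""
    return 0.5 * sum(x * x for x in v)


def demo(rounds: int = 100, probes: int = 10,
         sigma: float = 1e-3, lr: float = 0.1) -> Tuple[float, float]:
    """Run the hypothetical seed-based protocol; return (initial, final) loss."""
    seeds = random.Random(0)
    w = [1.0] * 5
    start = quadratic(w)
    for _ in range(rounds):
        msgs = [client_step(w, quadratic, seeds.randrange(2 ** 31), sigma)
                for _ in range(probes)]
        w = server_update(w, msgs, sigma, lr)
    return start, quadratic(w)


if __name__ == "__main__":
    print(demo())
```

Each uploaded message is two numbers regardless of the model dimension $n$, which is the point of such a trick. Who chooses the seeds (clients or the server) in the paper's scheme is not stated in this summary, so the sketch simply draws them from a shared generator.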

Abstract

Although Federated Learning (FL) enables collaborative learning in Artificial Intelligence of Things (AIoT) design, it fails to work on low-memory AIoT devices because of its heavy memory usage. To address this problem, various federated pruning methods have been proposed to reduce memory usage during inference. However, few of them substantially mitigate the memory burden during pruning and training. As an alternative, zeroth-order or backpropagation-free (BP-Free) methods can partially alleviate memory consumption, but they scale poorly and incur large computation overheads, since both the gradient estimation error and the number of floating point operations (FLOPs) grow with the dimensionality of the model parameters. In this paper, we propose a federated foresight pruning method based on the Neural Tangent Kernel (NTK) that integrates seamlessly with federated BP-Free training frameworks. We present an approximation of the federated NTK computed from the local NTK matrices. Moreover, we demonstrate that the data-free property of our method can substantially reduce the approximation error in scenarios with extreme data heterogeneity. Since our approach improves the performance of the vanilla BP-Free method with fewer FLOPs and truly alleviates memory pressure during training and inference, it makes FL more friendly to low-memory devices. Comprehensive experimental results obtained from simulation-based and real test-bed-based platforms show that our federated foresight pruning method not only preserves the capability of the dense model while reducing memory usage by up to 9x, but also boosts the performance of the vanilla BP-Free method with dramatically fewer FLOPs.
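The abstract argues that the gradient estimation error of BP-Free methods grows with the number of parameters, so pruning before training should help. The sketch below illustrates that effect on a toy quadratic loss: a forward-difference Gaussian estimator perturbs only the unpruned weights, and its error is compared for a dense and a 10%-density model. The mask uses weight magnitude as a placeholder saliency score; it is NOT the paper's NTK-based pruning objective, which is not reproduced here. The estimator form, the toy loss, and the helper names `zo_gradient`, `top_k_mask`, and `mean_relative_error` are assumptions for illustration only.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def zo_gradient(loss: Callable[[Vector], float], w: Vector, mask: List[int],
                K: int, sigma: float, rng: random.Random) -> Vector:
    """Forward-difference Gaussian estimate of grad loss(w).

    Only coordinates with mask[i] == 1 are perturbed, so the estimator works
    in a space whose dimension equals the number of unpruned weights.
    """
    n = len(w)
    base = loss(w)
    g = [0.0] * n
    for _ in range(K):
        delta = [rng.gauss(0.0, sigma) if m else 0.0 for m in mask]
        diff = loss([wi + di for wi, di in zip(w, delta)]) - base
        for i in range(n):
            g[i] += delta[i] * diff / (K * sigma ** 2)
    return g


def top_k_mask(scores: Vector, density: float) -> List[int]:
    """Keep the fraction `density` of coordinates with the largest scores."""
    k = max(1, round(density * len(scores)))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(order[:k])
    return [1 if i in keep else 0 for i in range(len(scores))]


def quadratic(v: Vector) -> float:
    """Toy loss 0.5 * ||v||^2, whose exact gradient is v itself."""
    return 0.5 * sum(x * x for x in v)


def mean_relative_error(n: int, density: float, K: int = 20, trials: int = 20,
                        sigma: float = 1e-3, seed: int = 0) -> float:
    """Average ||g_hat - g|| / ||g|| for a model pruned to the given density."""
    rng = random.Random(seed)
    w = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    # Placeholder saliency: weight magnitude (NOT the paper's NTK-based score).
    mask = top_k_mask([abs(x) for x in w], density)
    w = [wi * mi for wi, mi in zip(w, mask)]  # pruned weights are fixed at zero
    true_grad = w  # gradient of the quadratic; pruned entries are already zero
    norm = math.sqrt(sum(x * x for x in true_grad))
    total = 0.0
    for _ in range(trials):
        g = zo_gradient(quadratic, w, mask, K, sigma, rng)
        total += math.sqrt(sum((a - b) ** 2 for a, b in zip(g, true_grad))) / norm
    return total / trials


if __name__ == "__main__":
    for density in (1.0, 0.5, 0.1):
        print(density, round(mean_relative_error(200, density), 3))
```

For this toy setting and a small $\sigma$, the root-mean-square relative error is roughly $\sqrt{(k+1)/K}$, where $k$ is the number of perturbed weights, so shrinking $k$ by pruning lowers the error at a fixed probe budget $K$ and also reduces the per-probe FLOPs.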

Paper Structure

This paper contains 12 sections, 2 theorems, 16 equations, 5 figures, 6 tables, 1 algorithm.

Key Result

Theorem 1

(Estimation error) Let $\hat{\boldsymbol{\delta}}=\frac{1}{K}\sum_{k=1}^K\boldsymbol{\delta}_k$ and let the covariance matrix be $\widehat{\boldsymbol{\Sigma}}=\frac{1}{K\sigma^2}\sum_{k=1}^K\boldsymbol{\delta}_k\boldsymbol{\delta}_k^T$. Let $n$ be the dimension of the trainable parameters $\mathbf{W}$ of …
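Theorem 1 is stated in terms of two sample statistics of the $K$ perturbations. As a hedged illustration, the sketch below computes $\hat{\boldsymbol{\delta}}$ and $\widehat{\boldsymbol{\Sigma}}$ exactly as defined above and, assuming the perturbations are drawn i.i.d. from $\mathcal{N}(\mathbf{0}, \sigma^2\mathbf{I})$ (an assumption, since the distribution is not stated in this excerpt), measures how far $\widehat{\boldsymbol{\Sigma}}$ is from the identity, which is its expectation under that assumption. The helper names `stein_statistics`, `identity_gap`, and `sample_deltas` are hypothetical, and the bound of the theorem itself, truncated here, is not reproduced.

```python
import random
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def stein_statistics(deltas: List[Vector], sigma: float) -> Tuple[Vector, Matrix]:
    """Compute delta_hat = (1/K) sum_k delta_k and
    Sigma_hat = (1/(K sigma^2)) sum_k delta_k delta_k^T, as in Theorem 1."""
    K, n = len(deltas), len(deltas[0])
    delta_hat = [sum(d[i] for d in deltas) / K for i in range(n)]
    sigma_hat = [[sum(d[i] * d[j] for d in deltas) / (K * sigma ** 2)
                  for j in range(n)]
                 for i in range(n)]
    return delta_hat, sigma_hat


def identity_gap(sigma_hat: Matrix) -> float:
    """Frobenius norm of Sigma_hat - I."""
    n = len(sigma_hat)
    return sum((sigma_hat[i][j] - (1.0 if i == j else 0.0)) ** 2
               for i in range(n) for j in range(n)) ** 0.5


def sample_deltas(K: int, n: int, sigma: float, seed: int = 0) -> List[Vector]:
    """Draw K i.i.d. perturbations from N(0, sigma^2 I) (assumed distribution)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, sigma) for _ in range(n)] for _ in range(K)]


if __name__ == "__main__":
    n, sigma = 5, 0.01
    for K in (10, 100, 2000):
        _, s_hat = stein_statistics(sample_deltas(K, n, sigma), sigma)
        print(K, round(identity_gap(s_hat), 3))
```

Under the Gaussian assumption, $\mathbb{E}\|\widehat{\boldsymbol{\Sigma}}-\mathbf{I}\|_F^2 = n(n+1)/K$, so the gap shrinks as $K$ grows but grows with $n$, which matches the abstract's remark that the estimation error increases with the dimensionality of the model parameters.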

Figures (5)

  • Figure 1: Test accuracy comparison on CIFAR-10.
  • Figure 2: Test accuracy comparison on LeNet for CIFAR-10. $K$ is set to 200, and the number of communication rounds is set to 3000.
  • Figure 3: Test accuracy for CIFAR-10 with different $K$.
  • Figure 4: Learning performance on the real test-bed.
  • Figure 5: Test accuracy for CIFAR-100 with different $K$.

Theorems & Definitions (3)

  • Theorem 1
  • Definition 1
  • Proposition 1