FedAPTA: Federated Multi-task Learning for Heterogeneous Devices with Adaptive Layer-wise Pruning and Task-aware Aggregation

Zhen Yu, Yachao Yuan, Jin Wang, Zhipeng Cheng, Jianhua Hu

TL;DR

FedAPTA tackles federated multi-task learning on heterogeneous devices by combining adaptive layer-wise pruning with a heterogeneous model recovery mechanism and task-aware, cluster-based aggregation. It prunes each device's local model at the layer level, recovers full architectures using the latest global model to enable aggregation, and clusters devices by task using cosine-based similarities with HDBSCAN to perform per-task updates. Empirical results across five datasets and two architectures show FedAPTA outperforms nine SOTA FL methods, maintaining high accuracy even with substantial pruning. This approach enables scalable, privacy-preserving collaboration on diverse devices while preserving task-specific knowledge transfer.

Abstract

Federated Learning (FL) has shown considerable promise for Machine Learning (ML) across numerous devices, offering privacy protection, efficient data utilization, and dynamic collaboration. However, mobile devices typically have limited and heterogeneous computational capabilities, and different devices may even have different tasks. This client heterogeneity is a major bottleneck hindering the practical application of FL. Existing work mainly focuses on reducing the computation and communication overhead of FL for a single task, while overlooking the heterogeneity of computing resources across devices. To tackle this, we design FedAPTA, a federated multi-task learning framework. FedAPTA overcomes computing resource heterogeneity through a layer-wise model pruning technique that reduces local model size while considering both data and device heterogeneity. To aggregate structurally heterogeneous local models of different tasks, we introduce a heterogeneous model recovery strategy and a task-aware model aggregation method, which enable aggregation by infilling each local model architecture with the shared global model and clustering local models according to their specific tasks. We deploy FedAPTA on a realistic FL platform and benchmark it against nine SOTA FL methods. The experimental outcomes demonstrate that FedAPTA outperforms the state-of-the-art FL methods by up to 4.23%. Our code is available at https://github.com/Zhenzovo/FedAPTA.
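
The recovery step described in the abstract (infilling a pruned local model with the shared global model) can be illustrated with a minimal sketch. This is not the authors' implementation: the function name `recover_model`, the dictionary-of-lists model layout, and the assumption that pruned positions are stored as placeholders alongside a binary mask are all hypothetical choices made for illustration.

```python
from typing import Dict, List

# Hypothetical layout: a model maps layer names to flat lists of weights.
Model = Dict[str, List[float]]
Masks = Dict[str, List[int]]


def recover_model(local: Model, masks: Masks, global_model: Model) -> Model:
    """Fill pruned positions of a local model with global-model weights.

    Assumption: mask value 1 means the weight was kept and trained locally,
    0 means it was pruned. Pruned slots in `local` hold placeholder values.
    """
    recovered: Model = {}
    for name, global_layer in global_model.items():
        local_layer = local[name]
        mask = masks[name]
        # Keep the locally trained weight where the mask is 1,
        # otherwise fall back to the corresponding global weight.
        recovered[name] = [
            l if m == 1 else g
            for l, m, g in zip(local_layer, mask, global_layer)
        ]
    return recovered


# Toy example with two layers.
global_model = {"conv1": [0.5, -0.2, 0.1], "fc": [1.0, 2.0]}
masks = {"conv1": [1, 0, 1], "fc": [0, 1]}
local = {"conv1": [0.7, 0.0, 0.3], "fc": [0.0, 1.5]}

recovered = recover_model(local, masks, global_model)
```

After recovery, every local model has the same shape as the global model, which is what makes element-wise aggregation across structurally different local models possible.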

Paper Structure

This paper contains 19 sections, 5 equations, 3 figures, 6 tables, and 3 algorithms.

Figures (3)

  • Figure 1: Overview of the proposed FedAPTA framework. ➀ After the central server sends the global model, each device prunes it with adaptive layer-wise pruning to obtain its local model. ➁ Each device trains the local model on its local data. ➂ Each device uploads the trained local model and its mask matrix to the central server. ➃ The central server uses global model information to recover the received local models. ➄ The central server clusters all local models using a distance matrix based on their updates. ➅ The central server aggregates the local models belonging to each task to obtain the corresponding global models. ➆ The central server distributes the global model of each task to the corresponding devices.
  • Figure 2: Model accuracy of FedAPTA, FedLPS [jia2024fedlps], FedDrop [caldas2018expanding_feddrop], and FedRolex [alam2022fedrolex] on the ResNet18 model with different pruning ratios under the non-i.i.d. setting of the SVHN, CIFAR10, and EMNIST datasets.
  • Figure 3: Effects of model similarity metrics, where device 0 and device 1 handle the same task. Cosine similarity best captures the relationship between devices 0 and 1 since it produces the most distinguishable similarity values among all devices.
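
Step ➄ of Figure 1 and the comparison in Figure 3 both rely on a pairwise distance between device updates. A rough sketch of a cosine-based distance matrix follows, assuming each device's update has been flattened into a plain list of floats. The function names and the toy data are hypothetical. The clustering itself (HDBSCAN, per the TL;DR) is not shown; the matrix below is only the kind of input such a clustering step might consume.

```python
import math
from typing import List


def cosine_similarity(u: List[float], v: List[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def cosine_distance_matrix(updates: List[List[float]]) -> List[List[float]]:
    """Pairwise distance 1 - cos(u_i, u_j); the diagonal is fixed at 0."""
    n = len(updates)
    return [
        [
            0.0 if i == j else 1.0 - cosine_similarity(updates[i], updates[j])
            for j in range(n)
        ]
        for i in range(n)
    ]


# Toy updates: devices 0 and 1 point in the same direction (same task),
# device 2 is orthogonal to both (a different task).
updates = [[1.0, 0.0], [2.0, 0.0], [0.0, 1.0]]
dist = cosine_distance_matrix(updates)
```

In this toy case the distance between devices 0 and 1 is near zero even though their update magnitudes differ, while device 2 sits at distance 1 from both. That scale invariance is one reason a cosine-based measure can separate tasks more clearly than a magnitude-sensitive one.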