Task Diversity in Bayesian Federated Learning: Simultaneous Processing of Classification and Regression

Junliang Lyu, Yixuan Zhang, Xiaoling Lu, Feng Zhou

TL;DR

The paper addresses a limitation of federated learning, its focus on homogeneous tasks, by handling heterogeneous local tasks (classification and regression) within a Bayesian, task-aware framework. It introduces pFed-Mul, which fits a multi-output Gaussian process (MOGP) on each client and aggregates the clients' posterior information into a global MOGP prior; Pólya-Gamma augmentation makes the mean-field variational inference updates analytic. The authors report superior predictive performance, uncertainty calibration, and OOD detection, along with faster convergence attributed to the augmentation and to scalable inducing-point inference. Results are shown on synthetic and real data (CelebA, Dogcat) together with ablations and analysis, suggesting the method suits diverse, privacy-preserving, on-device learning scenarios. Code is publicly available.

Abstract

This work addresses a key limitation in current federated learning approaches, which predominantly focus on homogeneous tasks, neglecting the task diversity on local devices. We propose a principled integration of multi-task learning using multi-output Gaussian processes (MOGP) at the local level and federated learning at the global level. MOGP handles correlated classification and regression tasks, offering a Bayesian non-parametric approach that naturally quantifies uncertainty. The central server aggregates the posteriors from local devices, updating a global MOGP prior redistributed for training local models until convergence. Challenges in performing posterior inference on local devices are addressed through the Pólya-Gamma augmentation technique and mean-field variational inference, enhancing computational efficiency and convergence rate. Experimental results on both synthetic and real data demonstrate superior predictive performance, OOD detection, uncertainty calibration and convergence rate, highlighting the method's potential in diverse applications. Our code is publicly available at https://github.com/JunliangLv/task_diversity_BFL.
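As background on why Pólya-Gamma augmentation leads to analytic mean-field updates, the standard identity for the logistic sigmoid is sketched below. This is the general result from the Pólya-Gamma literature, assuming a logistic likelihood for the classification outputs; the symbols $f$ (latent function value) and $\omega$ (auxiliary variable) are our notation and may differ from the paper's.

```latex
\sigma(f) \;=\; \frac{e^{f}}{1 + e^{f}}
\;=\; \frac{1}{2}\, e^{f/2} \int_{0}^{\infty} e^{-\omega f^{2}/2}\; p_{\mathrm{PG}}(\omega \mid 1, 0)\, \mathrm{d}\omega
```

Conditioned on $\omega$, the likelihood is proportional to $\exp\!\left(f/2 - \omega f^{2}/2\right)$, which is Gaussian in form as a function of $f$ and therefore conjugate to a Gaussian process prior.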

Paper Structure

This paper contains 40 sections, 11 equations, 5 figures, 2 tables, 2 algorithms.

Figures (5)

  • Figure 1: The overview of our model pFed-Mul. Left: System diagram. The central server aggregates the posteriors from local devices, updating a global MOGP prior redistributed for training local models. Right: Bi-level optimization. The subfigure illustrates an iterative application of mean-field VI at the local level and hyperparameter tuning at the global level.
  • Figure 2: The estimated posterior of latent functions from pFed-Mul and pFed-St on one client. pFed-Mul achieves a better fit, especially for classification. Compared with pFed-St, pFed-Mul enables the transfer of knowledge from other task types, effectively reducing uncertainty, i.e., posterior variance (orange areas).
  • Figure 3: Reliability diagrams for all methods. Perfect calibration is plotted as a blue diagonal and the observed results as orange bars. The gap between the tops of the orange bars and the blue line indicates the degree of miscalibration; the expected calibration error (ECE) is reported in the top-left corner of each diagram for comparison. Our proposed method, pFed-Mul, demonstrates the best calibration performance, ranking first in terms of ECE.
  • Figure 4: OOD detection for CelebA and Dogcat. The predictive mean and variance of latent functions are depicted by blue lines and red areas beneath each image respectively. Positions where the image is masked as an OOD sample are denoted by black stars. A greater variance (wider area) is observed for OOD samples.
  • Figure 5: Convergence rate of all models on both datasets. pFed-Mul consistently reaches a comparable test-accuracy plateau while converging notably fast.
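The expected calibration error reported in the reliability diagrams (Figure 3) can be illustrated with a minimal sketch. It assumes the common equal-width binning definition of ECE; the function name, the default of 10 bins, and the input format are illustrative choices, not taken from the paper or its code.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """Equal-width-bin ECE: sum over bins of (bin weight) * |accuracy - mean confidence|.

    confidences: predicted probabilities of the chosen class, each in [0, 1].
    correct:     1 if the corresponding prediction was right, else 0.
    n_bins:      number of equal-width bins (an assumed default, not the paper's).
    """
    n = len(confidences)
    if n == 0 or n != len(correct):
        raise ValueError("inputs must be non-empty and of equal length")

    conf_sum = [0.0] * n_bins   # sum of confidences per bin
    hit_sum = [0.0] * n_bins    # number of correct predictions per bin
    count = [0] * n_bins        # number of samples per bin

    for c, y in zip(confidences, correct):
        # Bins are [k/n_bins, (k+1)/n_bins); a confidence of 1.0 goes in the last bin.
        b = min(int(c * n_bins), n_bins - 1)
        conf_sum[b] += c
        hit_sum[b] += y
        count[b] += 1

    ece = 0.0
    for b in range(n_bins):
        if count[b] > 0:
            accuracy = hit_sum[b] / count[b]
            mean_conf = conf_sum[b] / count[b]
            ece += (count[b] / n) * abs(accuracy - mean_conf)
    return ece


if __name__ == "__main__":
    # Four predictions made at 90% confidence, three of them correct.
    print(expected_calibration_error([0.9, 0.9, 0.9, 0.9], [1, 1, 1, 0]))
```

In a reliability diagram, each bar corresponds to one bin's accuracy, and the per-bin term above is that bar's distance from the diagonal weighted by the share of samples in the bin.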