Personalized Bayesian Federated Learning with Wasserstein Barycenter Aggregation

Ting Wei, Biao Mei, Junliang Lyu, Renquan Zhang, Feng Zhou, Yifan Sun

TL;DR

This work tackles personalized Bayesian federated learning under non-i.i.d. client data by combining nonparametric local posterior inference with geometry-aware global aggregation. It introduces FedWBA, which uses particle-based variational inference (SVGD) to represent local posteriors and a particle-based Wasserstein barycenter to aggregate these posteriors into a global prior, with kernel density estimation ensuring a continuous prior for SVGD updates. The authors prove a lower bound on the per-iteration ELBO increase at each client and consistency of the global barycenter, and empirically show improved prediction accuracy, better uncertainty calibration (lower ECE), and faster convergence across multiple datasets and client regimes. The approach offers a principled, uncertainty-aware alternative to Gaussian posteriors and simple parameter averaging, with implications for safety-critical, privacy-preserving distributed learning.

Abstract

Personalized Bayesian federated learning (PBFL) handles non-i.i.d. client data and quantifies uncertainty by combining personalization with Bayesian inference. However, existing PBFL methods face two limitations: restrictive parametric assumptions in client posterior inference and naive parameter averaging for server aggregation. To overcome these issues, we propose FedWBA, a novel PBFL method that enhances both local inference and global aggregation. At the client level, we use particle-based variational inference for nonparametric posterior representation. At the server level, we introduce particle-based Wasserstein barycenter aggregation, offering a more geometrically meaningful approach. Theoretically, we provide local and global convergence guarantees for FedWBA. Locally, we prove a KL divergence decrease lower bound per iteration for variational inference convergence. Globally, we show that the Wasserstein barycenter converges to the true parameter as the client data size increases. Empirically, experiments show that FedWBA outperforms baselines in prediction accuracy, uncertainty calibration, and convergence rate, with ablation studies confirming its robustness.
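To make the contrast between barycenter aggregation and naive pooling concrete, here is a minimal sketch for the one-dimensional special case only. It assumes every client sends the same number of equally weighted scalar particles, in which case the 2-Wasserstein barycenter reduces to a weighted average of the sorted particles (quantile averaging). The function name `barycenter_1d` and the toy data are hypothetical; this is not the paper's aggregation algorithm, which operates on high-dimensional particles.

```python
import numpy as np

def barycenter_1d(client_particles, weights=None):
    """2-Wasserstein barycenter of 1-D empirical measures (illustrative only).

    Assumes each row holds one client's particles, all rows have the same
    length, and particles within a row are equally weighted.
    """
    # Sorting each client's particles gives its empirical quantile function.
    sorted_sets = np.sort(np.asarray(client_particles, dtype=float), axis=1)
    num_clients = sorted_sets.shape[0]
    if weights is None:
        weights = np.full(num_clients, 1.0 / num_clients)
    weights = np.asarray(weights, dtype=float)
    # In 1-D the barycenter averages quantile functions across clients.
    return weights @ sorted_sets

# Two toy clients whose posteriors sit around -2 and +2.
clients = [[-1.0, -3.0, -2.0], [3.0, 1.0, 2.0]]
print(barycenter_1d(clients))        # [-1.  0.  1.]
print(np.sort(np.ravel(clients)))    # pooled mixture keeps both modes
```

The barycenter yields a single particle cloud located between the two clients while preserving their common spread, whereas pooling the particles produces a bimodal mixture.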

Paper Structure

This paper contains 33 sections, 6 theorems, 50 equations, 8 figures, 8 tables, 1 algorithm.

Key Result

Theorem 5.3

Under the paper's first two assumptions (ass:1 and ass:2), given SVGD iteration $l$ with client $k$ scheduled, the increase in the ELBO from iteration $l$ to $l+1$ satisfies an inequality (not reproduced here) in which $\epsilon$ is the SVGD step size, $\tilde{p}(\boldsymbol{\theta}\mid \mathcal{D}_k)$ denotes the unnormalized posterior distribution, and $D(q,p)$ is the kernelized Stein discrepancy.
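For readers unfamiliar with the iteration the theorem refers to, the following is a minimal sketch of one standard SVGD update with step size $\epsilon$ (`step_size` below). It assumes an RBF kernel with a fixed bandwidth and a toy standard-normal target; the names `rbf_kernel` and `svgd_step` are hypothetical, and the kernel, bandwidth, and target are illustrative assumptions rather than the paper's settings. In FedWBA the target would instead be the client's unnormalized posterior.

```python
import numpy as np

def rbf_kernel(particles, bandwidth):
    """Return k(x_i, x_j) and the gradient of k(x_j, x_i) with respect to x_j."""
    diffs = particles[:, None, :] - particles[None, :, :]   # diffs[i, j] = x_i - x_j
    sq_dists = np.sum(diffs ** 2, axis=-1)
    kernel = np.exp(-sq_dists / (2.0 * bandwidth ** 2))
    # For the RBF kernel, grad_{x_j} k(x_j, x_i) = (x_i - x_j) / h^2 * k(x_i, x_j).
    grad_kernel = diffs * kernel[:, :, None] / bandwidth ** 2
    return kernel, grad_kernel

def svgd_step(particles, grad_log_p, step_size, bandwidth=1.0):
    """One SVGD update: x_i <- x_i + step_size * phi(x_i)."""
    num_particles = particles.shape[0]
    kernel, grad_kernel = rbf_kernel(particles, bandwidth)
    scores = grad_log_p(particles)                # gradient of log target, shape (N, d)
    # Driving term pulls particles toward high density; the kernel-gradient
    # term repels particles from each other so they do not collapse.
    phi = (kernel @ scores + grad_kernel.sum(axis=1)) / num_particles
    return particles + step_size * phi

# Toy target (assumption): standard normal, so grad log p(x) = -x.
rng = np.random.default_rng(0)
particles = rng.normal(loc=5.0, scale=1.0, size=(50, 1))
for _ in range(200):
    particles = svgd_step(particles, lambda x: -x, step_size=0.1)
print(particles.mean(), particles.std())
```

The unnormalized posterior suffices here because SVGD only needs the gradient of the log density, which does not depend on the normalizing constant.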

Figures (8)

  • Figure 1: Overview of FedWBA. Left: System diagram. Clients upload local posterior particles to server for aggregation, server updates global prior particles and redistributes them to clients. Right: Local posterior particles from maximizing local ELBO, global prior particles as Wasserstein barycenter of $K$ local posteriors.
  • Figure 2: Reliability diagrams of top four methods on CIFAR-100. The perfect calibration is plotted as a red diagonal, and the actual results are presented as bar charts. The gap between the top of each bar and the red line represents the calibration error. The ECE is calculated and placed in the corner of the figure. FedWBA demonstrates the best calibration performance, ranking first in terms of ECE.
  • Figure 3: Comparison of convergence rates of different methods on MNIST, FMNIST, CIFAR-10, and CIFAR-100 with 50 clients. FedWBA exhibits the fastest convergence, with rapid growth in the first 10 communication rounds followed by steady improvement.
  • Figure 4: Comparison of Three Aggregation Methods: Wasserstein Barycenter (WB), Parameter Averaging (Avg), and Arithmetic Mean (Mixture).
  • Figure 5: Reliability diagrams of the six methods on CIFAR-100. The perfect calibration is plotted as a red diagonal line, and the actual results are presented as bar charts. The gap between the top of each bar and the red line represents the calibration error. The ECE is calculated and placed in the top-left corner of the figure. Among them, the method with the highest ECE value has the worst calibration performance.
  • ...and 3 more figures
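The reliability diagrams above summarize calibration with the expected calibration error (ECE). A minimal sketch of the usual binned estimate follows, assuming equal-width confidence bins and a hypothetical function name `expected_calibration_error`; the number of bins used in the paper is not stated here, so `n_bins=10` is an assumption.

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """Binned ECE: weighted average of |accuracy - confidence| over bins."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lower, upper in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        if in_bin.any():
            # Gap between bar height (accuracy) and the diagonal (confidence).
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            # Weight the gap by the fraction of samples falling in this bin.
            ece += in_bin.mean() * gap
    return ece

# Overconfident toy model: 95% confidence but only 3 of 4 predictions correct.
print(expected_calibration_error([0.95, 0.95, 0.95, 0.95], [1, 1, 1, 0]))
```

Each bin's gap corresponds to the distance between a bar and the diagonal in a reliability diagram, so a lower ECE means the bars track the diagonal more closely.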

Theorems & Definitions (7)

  • Theorem 5.3
  • Theorem 5.9
  • Definition E.1
  • Lemma E.2
  • Lemma E.3
  • Lemma E.4
  • Lemma E.5