Bayesian Federated Learning with Hamiltonian Monte Carlo: Algorithm and Theory
Jiajun Liang, Qian Zhang, Wei Deng, Qifan Song, Guang Lin
TL;DR
This paper tackles uncertainty-aware Bayesian federated learning on non-iid data by proposing FA-HMC, a Federated Averaging approach built on stochastic gradient Hamiltonian Monte Carlo to sample from the global posterior $\pi(\theta) \propto \exp(-f(\theta))$ with $f(\theta)=\sum_{c=1}^N w_c f^{(c)}(\theta)$. It derives non-asymptotic Wasserstein-2 convergence guarantees under $\mu$-strong convexity and Hessian smoothness, and shows how dimension $d$, gradient noise $\sigma_g$, momentum correlation $\rho$, and local update frequency influence convergence and communication costs; the analysis also establishes tightness via lower bounds. Empirically, FA-HMC outperforms FA-LD on simulated Bayesian logistic regression and real datasets (Fashion-MNIST, KMNIST, CIFAR-2), while achieving lower communication overhead and providing uncertainty quantification. The results suggest FA-HMC is robust to hyperparameters and suitable for privacy-conscious federated settings, with potential extensions to non-convex settings and heterogeneous local dynamics.
Abstract
This work introduces a novel and efficient Bayesian federated learning algorithm, namely, the Federated Averaging stochastic Hamiltonian Monte Carlo (FA-HMC), for parameter estimation and uncertainty quantification. We establish rigorous convergence guarantees for FA-HMC on non-iid data sets, under strong convexity and Hessian smoothness assumptions. Our analysis investigates the effects of parameter space dimension, noise on gradients and momentum, and the frequency of communication (between the central node and local nodes) on the convergence and communication costs of FA-HMC. We further establish the tightness of our analysis by showing that the convergence rate cannot be improved even for the continuous FA-HMC process. Moreover, extensive empirical studies demonstrate that FA-HMC outperforms the existing Federated Averaging-Langevin Monte Carlo (FA-LD) algorithm.
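To make the federated sampling idea concrete, the following is a minimal sketch and not the paper's exact FA-HMC update. It assumes a toy target in which each client holds a quadratic potential $f^{(c)}(\theta) = \tfrac{1}{2}\|\theta - m_c\|^2$ with weights summing to one, so the global posterior is Gaussian with mean $\sum_c w_c m_c$ and identity covariance. Each client runs a few leapfrog steps with noisy gradients and a partially refreshed momentum, and the server then averages the clients. The function name `fa_hmc_sketch`, its parameters, the choice to share the momentum noise across clients, and the choice to average both positions and momenta are all illustrative assumptions.

```python
import numpy as np


def fa_hmc_sketch(client_means, weights, rounds=5000, local_steps=5,
                  step=0.2, rho=0.5, grad_noise=0.1, seed=0):
    """Toy federated sampler for f(theta) = sum_c w_c * 0.5 * ||theta - m_c||^2.

    Hypothetical helper for illustration only; not the authors' implementation.
    """
    rng = np.random.default_rng(seed)
    m = np.asarray(client_means, dtype=float)   # shape (N, d): client centres
    w = np.asarray(weights, dtype=float)        # shape (N,): assumed to sum to 1
    n_clients, dim = m.shape
    theta = np.zeros((n_clients, dim))          # one position per client
    p = np.zeros((n_clients, dim))              # one momentum per client
    samples = np.empty((rounds, dim))

    def noisy_grad(th):
        # Gradient of each local potential plus independent Gaussian noise.
        return (th - m) + grad_noise * rng.standard_normal(th.shape)

    for r in range(rounds):
        for _ in range(local_steps):
            # Partial momentum refresh with correlation rho.
            # Assumption: the fresh noise xi is shared by all clients.
            xi = rng.standard_normal(dim)
            p = rho * p + np.sqrt(1.0 - rho**2) * xi
            # One leapfrog step on each client's own potential.
            p = p - 0.5 * step * noisy_grad(theta)
            theta = theta + step * p
            p = p - 0.5 * step * noisy_grad(theta)
        # Communication round: weighted average, then broadcast to all clients.
        theta_bar, p_bar = w @ theta, w @ p
        theta = np.tile(theta_bar, (n_clients, 1))
        p = np.tile(p_bar, (n_clients, 1))
        samples[r] = theta_bar
    return samples


# Four clients with non-identical data centres and equal weights.
means = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]]
weights = [0.25, 0.25, 0.25, 0.25]
draws = fa_hmc_sketch(means, weights)
print(draws[500:].mean(axis=0), draws[500:].var(axis=0))
```

For this toy target the empirical mean of the draws should sit near $(1, 1)$ and the per-coordinate variance near one, up to discretization and Monte Carlo error. Increasing `local_steps` reduces the number of communication rounds per unit of simulated time, which is the trade-off between local update frequency and communication cost that the analysis studies.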
