Federated Bayesian Deep Learning: The Application of Statistical Aggregation Methods to Bayesian Models

John Fischer, Marko Orescanin, Justin Loomis, Patrick McClure

TL;DR

Federated Bayesian Deep Learning (FBDL) investigates how to aggregate Bayesian neural network models across distributed clients in a federated setting. The study redefines statistical aggregation methods for mean-field variational inference (VI) and MC dropout Bayesian models. It evaluates six aggregation strategies (NWA, WS, LP, Conflation, WC, DWC) on IID and non-IID CIFAR-10 partitions with a fully variational ResNet-20. It also evaluates MC dropout as a lightweight Bayesian alternative and analyzes several client-weighting schemes and prior-update policies. The results show that the choice of aggregation strategy substantially affects accuracy, calibration, and training efficiency. WS, WC, and Conflation generally outperform NWA and LP, DWC benefits from pretraining, and MC dropout offers a practical baseline with favorable compute and communication tradeoffs. The work provides practical guidance for deploying Bayesian federated learning (FL) in remote sensing and safety-critical domains. It also highlights directions for future work on larger, real-world datasets and dynamic priors.
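The sketch below makes the contrast between pooling-style and product-style aggregation concrete. It treats each client's mean-field posterior as an independent Gaussian (mean, variance) per parameter. `linear_pool` moment-matches the weighted mixture of the client posteriors. `conflation` takes the normalized product of the client densities. These are the textbook definitions of linear opinion pooling and conflation. The paper's exact formulations of LP, Conflation, and the other methods (including weighted variants, client weighting, and prior handling) may differ. The function names, the `(means, variances)` format, and dataset-size weighting are illustrative assumptions.

```python
from typing import List, Tuple

# Hypothetical format: a client's mean-field posterior as (means, variances),
# one entry per network parameter, with all parameters assumed independent.
Gaussian = Tuple[List[float], List[float]]


def client_weights(num_examples: List[int]) -> List[float]:
    """Weights proportional to local dataset size (one assumed weighting scheme)."""
    total = sum(num_examples)
    return [n / total for n in num_examples]


def linear_pool(clients: List[Gaussian], weights: List[float]) -> Gaussian:
    """Moment-match the weighted mixture of client Gaussians, parameter by parameter."""
    means, variances = [], []
    for j in range(len(clients[0][0])):
        mean = sum(w * mu[j] for w, (mu, _) in zip(weights, clients))
        # Mixture variance = E[var_i + mu_i^2] - (E[mu_i])^2
        second = sum(w * (var[j] + mu[j] ** 2) for w, (mu, var) in zip(weights, clients))
        means.append(mean)
        variances.append(second - mean ** 2)
    return means, variances


def conflation(clients: List[Gaussian]) -> Gaussian:
    """Normalized product of client Gaussians: precisions add, means are precision-weighted."""
    means, variances = [], []
    for j in range(len(clients[0][0])):
        precision = sum(1.0 / var[j] for _, var in clients)
        mean = sum(mu[j] / var[j] for mu, var in clients) / precision
        means.append(mean)
        variances.append(1.0 / precision)
    return means, variances


# Two clients that disagree on a single parameter.
clients = [([0.0], [1.0]), ([2.0], [1.0])]
print(linear_pool(clients, client_weights([50, 50])))  # ([1.0], [2.0])
print(conflation(clients))                             # ([1.0], [0.5])
```

The two rules move the aggregated variance in opposite directions. Conflation's variance is never larger than the smallest client variance, because precisions add. Linear pooling's variance is at least the weighted average client variance and grows when client means disagree. This is one plausible reason the choice of method affects calibration.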

Abstract

Federated learning (FL) is an approach to training machine learning models that takes advantage of multiple distributed datasets while maintaining data privacy and reducing the communication costs associated with sharing local datasets. Aggregation strategies have been developed to pool or fuse the weights and biases of distributed deterministic models. However, modern deterministic deep learning (DL) models are often poorly calibrated and cannot communicate a measure of epistemic uncertainty in their predictions, a capability that is desirable for remote sensing platforms and safety-critical applications. Conversely, Bayesian DL models are often well calibrated and can quantify and communicate a measure of epistemic uncertainty while achieving competitive prediction accuracy. Unfortunately, the weights and biases of Bayesian DL models are defined by probability distributions. As a result, simply applying the aggregation methods used in FL schemes for deterministic models is either impossible or yields sub-optimal performance. In this work, we use independent and identically distributed (IID) and non-IID partitions of the CIFAR-10 dataset and a fully variational ResNet-20 architecture to analyze six different aggregation strategies for Bayesian DL models. Additionally, we analyze the traditional federated averaging approach applied to an approximate Bayesian Monte Carlo (MC) dropout model as a lightweight alternative to more complex variational inference methods in FL. We show that aggregation strategy is a key hyperparameter in the design of a Bayesian FL system, with downstream effects on accuracy, calibration, uncertainty quantification, training stability, and client compute requirements.
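An MC dropout model keeps ordinary point-estimate weights and only samples dropout masks at inference, so the standard federated averaging (FedAvg) server step applies to it unchanged. Below is a minimal sketch of that step. It assumes clients are weighted by local dataset size, which is the usual FedAvg choice; the paper also studies other weighting schemes. Parameters are passed as a dictionary of flat lists, a hypothetical format chosen for brevity.

```python
from typing import Dict, List

# Hypothetical format: layer name -> flat list of that layer's parameter values.
Params = Dict[str, List[float]]


def fedavg(client_params: List[Params], num_examples: List[int]) -> Params:
    """Federated averaging: dataset-size-weighted mean of each client's parameters."""
    total = sum(num_examples)
    weights = [n / total for n in num_examples]
    averaged: Params = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        averaged[name] = [
            sum(w * params[name][i] for w, params in zip(weights, client_params))
            for i in range(length)
        ]
    return averaged


# Client 2 holds three times as much data, so it pulls the average toward its weights.
print(fedavg([{"w": [1.0, 2.0]}, {"w": [3.0, 6.0]}], [1, 3]))  # {'w': [2.5, 5.0]}
```

Each client sends one value per weight. A mean-field VI model sends a mean and a variance parameter per weight, so the MC dropout model has the lighter communication load.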

Paper Structure (32 sections, 21 equations, 9 figures, 5 tables, 1 algorithm)

Figures (9)

  • Figure 1: ResNet Model Block Structure. (a) depicts the structure of a ResNet block in the deterministic configuration. (b) depicts the same ResNet block in the MC dropout configuration, with a dropout layer added after each convolutional and linear layer; these dropout layers remain enabled during inference. (c) depicts the ResNet block in the Flipout configuration, in which each convolutional and linear layer is replaced by the corresponding TensorFlow Probability Flipout layer. Changes from the deterministic configuration are highlighted in orange.
  • Figure 2: Uncertainty calibration plots for $E=1$. Comparison of VI aggregation methods: test set accuracy vs. ratio of data retained for the IID and non-IID 2-class data distributions. Plots for the non-IID Dirichlet distribution show trends and relationships that mirror the IID results. Normalized entropy and epistemic uncertainty plots are shown; aleatoric uncertainty plots show trends and relationships that mirror those of normalized entropy.
  • Figure 3: Learning curve plots for $E=1$. Plot of test set accuracy vs. federation round for VI aggregation methods for each data distribution.
  • Figure 4: Uncertainty calibration plots for $E=5$. Comparison of VI aggregation methods: test set accuracy vs. ratio of data retained for the IID and non-IID 2-class data distributions. Plots for the non-IID Dirichlet distribution show trends and relationships that mirror the IID results. Normalized entropy and epistemic uncertainty plots are shown; aleatoric uncertainty plots show trends and relationships that mirror those of normalized entropy.
  • Figure 5: Learning curve plots for $E=5$. Plot of test set accuracy vs. federation round for VI aggregation methods for each data distribution.
  • ...and 4 more figures
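The uncertainty calibration figures plot accuracy against the ratio of data retained when predictions are ranked by an uncertainty score. Below is a minimal sketch of that kind of analysis. It uses one common decomposition over T stochastic forward passes (MC dropout masks or Flipout weight samples): predictive entropy is split into an aleatoric part (expected entropy) and an epistemic part (mutual information). Several details are assumptions, and the paper's exact definitions may differ:

- entropy is normalized by log K,
- the least uncertain examples are retained first,
- the rounding rule for the retained count,
- all function names.

```python
import math
from typing import List, Tuple

# probs[t][n][k]: softmax output of stochastic pass t for example n and class k.
Samples = List[List[List[float]]]


def entropy(p: List[float]) -> float:
    """Shannon entropy in nats, skipping zero-probability classes."""
    return -sum(x * math.log(x) for x in p if x > 0.0)


def uncertainties(probs: Samples) -> List[Tuple[float, float, float]]:
    """Per example: (normalized predictive entropy, aleatoric, epistemic)."""
    num_samples, num_examples, num_classes = len(probs), len(probs[0]), len(probs[0][0])
    results = []
    for n in range(num_examples):
        mean_p = [sum(probs[t][n][k] for t in range(num_samples)) / num_samples
                  for k in range(num_classes)]
        total = entropy(mean_p)  # entropy of the MC-averaged prediction
        aleatoric = sum(entropy(probs[t][n]) for t in range(num_samples)) / num_samples
        epistemic = total - aleatoric  # mutual information between prediction and weights
        results.append((total / math.log(num_classes), aleatoric, epistemic))
    return results


def accuracy_vs_retained(scores: List[float], correct: List[bool],
                         ratios: List[float]) -> List[float]:
    """Accuracy on the retained fraction, keeping the lowest-uncertainty examples first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    accs = []
    for r in ratios:
        keep = order[:max(1, round(r * len(order)))]
        accs.append(sum(correct[i] for i in keep) / len(keep))
    return accs


# Two passes, two examples, two classes: the passes agree on example 0 and disagree on example 1.
probs = [[[1.0, 0.0], [1.0, 0.0]],
         [[1.0, 0.0], [0.0, 1.0]]]
scores = [u[0] for u in uncertainties(probs)]  # normalized predictive entropy per example
print(accuracy_vs_retained(scores, [True, False], [0.5, 1.0]))  # [1.0, 0.5]
```

For a well-calibrated model, accuracy should rise as the retained ratio shrinks, because the discarded examples are the ones the model was least certain about.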