FedBEns: One-Shot Federated Learning based on Bayesian Ensemble

Jacopo Talpini, Marco Savi, Giovanni Neglia

TL;DR

This paper analyzes the One-Shot FL problem through the lens of Bayesian inference and proposes FedBEns, an algorithm that leverages the inherent multimodality of local loss functions to find better global models.

Abstract

One-Shot Federated Learning (FL) is a recent paradigm that enables multiple clients to cooperatively learn a global model in a single round of communication with a central server. In this paper, we analyze the One-Shot FL problem through the lens of Bayesian inference and propose FedBEns, an algorithm that leverages the inherent multimodality of local loss functions to find better global models. Our algorithm leverages a mixture of Laplace approximations for the clients' local posteriors, which the server then aggregates to infer the global model. We conduct extensive experiments on various datasets, demonstrating that the proposed method outperforms competing baselines that typically rely on unimodal approximations of the local losses.
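The client/server pipeline described in the abstract can be sketched on a 1-D toy problem in the spirit of Figure 1. This is a minimal illustration under our own assumptions, not the authors' implementation. The toy losses, the prior variance `PRIOR_VAR`, the step size, the equal mixture weights (one Laplace component per gradient-descent run), the finite-difference derivatives, and the grid search standing in for the server's numerical optimization are all hypothetical choices. Likewise, `client_update` and `server_aggregate` are illustrative names.

```python
import numpy as np

PRIOR_VAR = 10.0  # assumed Gaussian prior N(0, PRIOR_VAR), shared by all clients

def neg_log_prior(w):
    return 0.5 * w**2 / PRIOR_VAR

# Hypothetical toy losses mirroring Figure 1: client 1 is bimodal, client 2 is unimodal.
losses = [
    lambda w: 2.0 * (w**2 - 1.0) ** 2 + 0.3 * w,  # deeper minimum near w = -1, secondary near w = +1
    lambda w: 2.0 * (w - 1.2) ** 2,
]

def num_grad(f, w, eps=1e-5):
    return (f(w + eps) - f(w - eps)) / (2 * eps)

def num_hess(f, w, eps=1e-4):
    return (f(w + eps) - 2 * f(w) + f(w - eps)) / eps**2

def client_update(loss, inits, lr=0.01, steps=2000):
    """Gradient descent from several inits; one Laplace component (mode, precision) per run."""
    def U(w):  # local negative log posterior
        return loss(w) + neg_log_prior(w)
    modes = []
    for w in inits:
        for _ in range(steps):
            w -= lr * num_grad(U, w)
        modes.append(w)
    modes = np.array(modes)
    precs = np.array([num_hess(U, m) for m in modes])  # curvature at each mode
    return modes, precs

def log_normal(w, mean, prec):
    return 0.5 * np.log(prec / (2 * np.pi)) - 0.5 * prec * (w - mean) ** 2

def log_mixture(w, modes, precs):
    """log of (1/K) * sum_k N(w; mode_k, 1/prec_k), via a stable log-sum-exp."""
    comps = log_normal(w[:, None], modes[None, :], precs[None, :])
    top = comps.max(axis=1)
    return top + np.log(np.exp(comps - top[:, None]).sum(axis=1)) - np.log(len(modes))

def server_aggregate(client_params, grid):
    """Maximize sum_c log q_c(w) - (C - 1) log p(w) over a 1-D grid (cf. Proposition 3.1)."""
    C = len(client_params)
    log_post = sum(log_mixture(grid, m, p) for m, p in client_params)
    log_post += (C - 1) * neg_log_prior(grid)  # subtract (C - 1) log p(w), up to a constant
    return grid[np.argmax(log_post)]

inits = np.linspace(-2.0, 2.0, 8)
client_params = [client_update(loss, inits) for loss in losses]
grid = np.linspace(-3.0, 3.0, 60001)
w_mixture = server_aggregate(client_params, grid)

# Unimodal baseline: each client keeps only the Laplace component of its lowest-loss run.
unimodal = []
for loss, (m, p) in zip(losses, client_params):
    k = int(np.argmin([loss(x) + neg_log_prior(x) for x in m]))
    unimodal.append((m[k:k + 1], p[k:k + 1]))
w_unimodal = server_aggregate(unimodal, grid)

# Ground truth: minimizer of the exact global negative log posterior (losses + one prior).
w_true = grid[np.argmin(sum(loss(grid) for loss in losses) + neg_log_prior(grid))]
print(w_true, w_mixture, w_unimodal)
```

In this toy setup, Client 1's lowest-loss run sits in the left basin. The unimodal baseline therefore lands on the wrong side of the origin. The mixture keeps Client 1's secondary mode, which is the one that matters for the global optimum, so its estimate stays close to the true minimizer.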

Paper Structure

This paper contains 27 sections, 2 theorems, 11 equations, 5 figures, 5 tables, 1 algorithm.

Key Result

Proposition 3.1

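Given $C$ datasets and the posteriors $p(\mathbf{w}|\mathcal{D}_c)$ of the same model $f(\mathbf{w})$, assume that (i) the datasets are conditionally independent given the model and (ii) the prior is the same across clients. Under these assumptions, the global posterior can be written as:

$$p(\mathbf{w}|\mathcal{D}) \propto \frac{\prod_{c=1}^{C} p(\mathbf{w}|\mathcal{D}_c)}{p(\mathbf{w})^{C-1}}, \qquad \mathcal{D} = \bigcup_{c=1}^{C} \mathcal{D}_c.$$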
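The identity in Proposition 3.1 multiplies the local posteriors and divides by the shared prior $C-1$ times. It can be checked numerically on a conjugate model where the exact global posterior is known. The following is a minimal sketch under assumed settings: a scalar parameter, Gaussian observations $y_i \sim \mathcal{N}(w, \sigma^2)$, a Gaussian prior, and three clients with made-up sample sizes. It compares the grid-evaluated right-hand side of the proposition with the centralized posterior computed from the pooled data.

```python
import numpy as np

def gaussian_posterior(y, prior_prec, noise_var):
    """Exact posterior N(mean, 1/prec) of w given y_i ~ N(w, noise_var) and prior w ~ N(0, 1/prior_prec)."""
    prec = prior_prec + len(y) / noise_var
    mean = (y.sum() / noise_var) / prec
    return mean, prec

def log_normal(w, mean, prec):
    return 0.5 * np.log(prec / (2 * np.pi)) - 0.5 * prec * (w - mean) ** 2

rng = np.random.default_rng(0)
prior_prec, noise_var, true_w = 1.0, 4.0, 1.5
clients = [rng.normal(true_w, np.sqrt(noise_var), size=n) for n in (5, 20, 50)]  # unequal client sizes
C = len(clients)

# Reference: centralized posterior computed from the pooled data.
global_mean, global_prec = gaussian_posterior(np.concatenate(clients), prior_prec, noise_var)

# Proposition 3.1 on a grid: sum of local log posteriors minus (C - 1) log prior, then renormalize.
grid = np.linspace(-5.0, 5.0, 20001)
dw = grid[1] - grid[0]
log_post = sum(log_normal(grid, *gaussian_posterior(y, prior_prec, noise_var)) for y in clients)
log_post = log_post - (C - 1) * log_normal(grid, 0.0, prior_prec)
post = np.exp(log_post - log_post.max())
post /= post.sum() * dw

exact = np.exp(log_normal(grid, global_mean, global_prec))
print("max abs error:", np.max(np.abs(post - exact)))
```

In this Gaussian case the identity can also be read off in closed form. Each local precision is $\lambda_0 + n_c/\sigma^2$. Summing over clients and subtracting $(C-1)\lambda_0$ leaves $\lambda_0 + N/\sigma^2$, which is the centralized posterior precision.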

Figures (5)

  • Figure 1: An illustration of One-Shot FL in a toy 2D setting with two clients. The left plot shows one client with a bimodal loss function (Client 1) and one with a unimodal loss function (Client 2); the right plot shows the overall 'ground-truth' global loss. The left plot also shows several stochastic gradient descent (SGD) trajectories on the two losses (dotted lines), started from different random initializations to mimic the clients' training. In the right plot, the red dots represent the reconstructed global optimum under two choices for Client 1. Client 1 may approximate its local loss as quadratic around its global minimum; this is the more likely outcome, represented by the bigger dot. Alternatively, it may approximate the loss around its secondary minimum. The yellow star denotes the global optimum inferred by our approach, based on an ensemble of all the optimal solutions found by the SGD runs. For recovering the global optimum, the secondary minimum of Client 1 is more relevant than its global one.
  • Figure 2: Test accuracy as a function of the number of mixtures for FedBEns with Kronecker factorization. For each dataset, the mean accuracy is represented by a solid line, while the shaded band indicates the standard deviation, both computed over 5 seeds, for various heterogeneity parameters.
  • Figure 3: Test accuracy as a function of the number of mixtures with different Hessian approximations. Curves represent the mean accuracy computed over 3 seeds. Standard deviation bands are omitted for clarity.
  • Figure 4: Test accuracy for FedBEns with Kronecker factorization and 5 mixtures, as a function of the temperature parameter. Curves represent the mean accuracy computed over 3 seeds.
  • Figure 5: Test accuracy for FedBEns with Kronecker factorization and 5 mixtures, as a function of the prior variance. Curves represent the mean accuracy computed over 3 seeds.

Theorems & Definitions (2)

  • Proposition 3.1
  • Proposition 1.1