
Improved Modelling of Federated Datasets using Mixtures-of-Dirichlet-Multinomials

Jonathan Scott, Áine Cahill

TL;DR

This work tackles the challenge of realistically simulating federated training by learning a distribution over per-client histograms that reflects non-IID heterogeneity. It introduces the Mixture-of-Dirichlet-Multinomials (MDM) model, which extends the Dirichlet-Multinomial to a mixture and adds an explicit model for each client's number of samples. It also develops a fully federated, EM-based maximum-likelihood procedure to infer the mixture weights, the Dirichlet parameters, and the per-component sample-count distributions. The authors prove convergence of the EM updates. They also show that server-side simulations using the inferred MDM parameters mirror true federated training more closely than fully IID simulations on datasets such as FEMNIST, CIFAR10, and Folktables. Because the procedure relies only on aggregate statistics, it helps preserve privacy and is compatible with differential privacy. This offers a practical path to faster, more reliable hyperparameter tuning and model selection in federated settings.
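The MDM generative process summarized above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' implementation. `tau` stands for the mixture weights, `alphas[k]` for the Dirichlet parameters of component k, and `count_dists[k]` for the sample-count distribution of component k. We assume these roughly play the roles of $\boldsymbol{\tau}$, $\mathbf{A}$ and $\boldsymbol{\Pi}$ in Theorem 3.1. The representation of each count distribution as a finite `{count: probability}` dict is a hypothetical choice, and all function names are illustrative.

```python
import random
from typing import Dict, List, Sequence, Tuple


def sample_dirichlet(alpha: Sequence[float], rng: random.Random) -> List[float]:
    # A Dirichlet(alpha) draw is a vector of independent Gamma(alpha_j, 1)
    # draws normalized to sum to one.
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]


def sample_multinomial(n: int, probs: Sequence[float], rng: random.Random) -> List[int]:
    # Draw n category indices with the given probabilities and count them.
    counts = [0] * len(probs)
    for idx in rng.choices(range(len(probs)), weights=probs, k=n):
        counts[idx] += 1
    return counts


def sample_mdm_client(
    tau: Sequence[float],
    alphas: Sequence[Sequence[float]],
    count_dists: Sequence[Dict[int, float]],
    rng: random.Random,
) -> Tuple[int, int, List[int]]:
    """Sample (component z, sample count n, class histogram c) for one client."""
    # 1. Pick a mixture component z ~ Categorical(tau).
    z = rng.choices(range(len(tau)), weights=tau, k=1)[0]
    # 2. Pick the client's sample count n from component z's count model
    #    (hypothetical finite {count: probability} dict).
    sizes = list(count_dists[z])
    n = rng.choices(sizes, weights=[count_dists[z][s] for s in sizes], k=1)[0]
    # 3. Draw class proportions p ~ Dirichlet(alpha_z), then c ~ Multinomial(n, p).
    p = sample_dirichlet(alphas[z], rng)
    c = sample_multinomial(n, p, rng)
    return z, n, c
```

For example, consider `tau = [0.3, 0.7]` with two 3-class components and `count_dists = [{10: 0.5, 20: 0.5}, {50: 1.0}]`. In that case every client drawn from the second component has exactly 50 samples. Both n and c are sampled conditional on the same z, which lets client size and label skew be correlated across clients.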

Abstract

In practice, training using federated learning can be orders of magnitude slower than standard centralized training. This severely limits the amount of experimentation and tuning that can be done, making it challenging to obtain good performance on a given task. Server-side proxy data can be used to run training simulations, for instance for hyperparameter tuning. This can greatly speed up the training pipeline by reducing the overall number of tuning runs that must be performed on the true clients. However, it is challenging to ensure that these simulations accurately reflect the dynamics of real federated training. In particular, the proxy data used for simulations often comes as a single centralized dataset without a partition into distinct clients, and partitioning this data naively can lead to simulations that poorly reflect real federated training. In this paper we address the challenge of partitioning centralized data in a way that reflects the statistical heterogeneity of the true federated clients. We propose a fully federated, theoretically justified algorithm that efficiently learns the distribution of the true clients, and we observe improved server-side simulations when using the inferred distribution to create simulated clients from the centralized data.
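One way to turn sampled histograms into simulated clients, as described above, is to draw the requested number of proxy examples per class without replacement. The sketch below makes several assumptions. The proxy data is assumed to be grouped already by class index (`examples_by_class`). The `client_histograms` are assumed to have been sampled from the inferred distribution. The function name `partition_proxy_data` is hypothetical. When a class pool runs out, the sketch truncates the request, which is an illustrative choice and not necessarily what the paper does.

```python
import random
from typing import Dict, List, Sequence


def partition_proxy_data(
    examples_by_class: Dict[int, List[int]],
    client_histograms: Sequence[Sequence[int]],
    rng: random.Random,
) -> List[List[int]]:
    """Assign proxy-example indices to simulated clients.

    client_histograms[i][j] is the number of class-j examples requested
    for simulated client i (e.g. a histogram sampled from an inferred MDM).
    """
    # Shuffled copy of each class pool, so clients receive random examples
    # and the caller's lists are left untouched.
    pools = {j: rng.sample(idxs, len(idxs)) for j, idxs in examples_by_class.items()}
    clients: List[List[int]] = []
    for hist in client_histograms:
        client: List[int] = []
        for j, requested in enumerate(hist):
            pool = pools.get(j, [])
            # Assumption: if a class pool runs out, give the client what is left.
            take = min(requested, len(pool))
            client.extend(pool[:take])
            del pool[:take]  # sampling without replacement across clients
        clients.append(client)
    return clients
```

Drawing without replacement keeps the simulated clients disjoint. An alternative design would resample with replacement when proxy data for some class is scarce.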

Paper Structure (41 sections, 1 theorem, 38 equations, 22 figures, 1 table, 2 algorithms)

Key Result

Theorem 3.1

Let $(\mathbf{c}_i, n_i)_{i=1}^M$ be observed histogram and sample-count data, and let $(\boldsymbol{\tau}^{(0)}, \boldsymbol{\Pi}^{(0)}, \mathbf{A}^{(0)})$ be an initialization of the parameters of the Mixture-of-Dirichlet-Multinomials model (Eq. MD), where $Z_i$ is the latent variable indicating the mixture component from which the $i$th sample was drawn …
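To give some intuition for the EM procedure that Theorem 3.1 concerns, the following is a hedged sketch of the E-step for a mixture of Dirichlet-Multinomials, together with the standard closed-form update for the mixture weights. It uses the Dirichlet-Multinomial log-probability $\log p(\mathbf{c} \mid n, \boldsymbol{\alpha}) = \log n! - \sum_j \log c_j! + \log\Gamma(\alpha_0) - \log\Gamma(n+\alpha_0) + \sum_j [\log\Gamma(c_j+\alpha_j) - \log\Gamma(\alpha_j)]$, with $\alpha_0 = \sum_j \alpha_j$. The sketch rests on three assumptions:

- The count distributions are hypothetical `{count: probability}` dicts.
- The Dirichlet-parameter M-step is omitted, because it has no closed form and typically needs an iterative solver.
- Each client computes its own responsibilities locally, so the server only needs their sums. This illustrates the idea of aggregate statistics and is not the paper's exact protocol.

```python
import math
from typing import Dict, List, Sequence


def dm_log_pmf(c: Sequence[int], alpha: Sequence[float]) -> float:
    # Dirichlet-Multinomial log-probability of histogram c (with n = sum(c)).
    n = sum(c)
    a0 = sum(alpha)
    out = math.lgamma(n + 1) - sum(math.lgamma(cj + 1) for cj in c)
    out += math.lgamma(a0) - math.lgamma(n + a0)
    out += sum(math.lgamma(cj + aj) - math.lgamma(aj) for cj, aj in zip(c, alpha))
    return out


def responsibilities(
    c: Sequence[int],
    tau: Sequence[float],
    alphas: Sequence[Sequence[float]],
    count_dists: Sequence[Dict[int, float]],
) -> List[float]:
    # E-step for one client: posterior P(Z = k | c, n) over components k.
    n = sum(c)
    logs = []
    for k in range(len(tau)):
        p_n = count_dists[k].get(n, 0.0)
        if tau[k] <= 0.0 or p_n <= 0.0:
            logs.append(-math.inf)  # component k cannot have produced this client
        else:
            logs.append(math.log(tau[k]) + math.log(p_n) + dm_log_pmf(c, alphas[k]))
    # Log-sum-exp normalization (assumes at least one component has finite log-prob).
    m = max(logs)
    w = [math.exp(v - m) for v in logs]
    s = sum(w)
    return [x / s for x in w]


def update_mixture_weights(all_resps: Sequence[Sequence[float]]) -> List[float]:
    # M-step for tau: average responsibility per component. Only the per-component
    # sums are needed, so a server could aggregate them across clients.
    num_clients = len(all_resps)
    num_components = len(all_resps[0])
    return [sum(r[k] for r in all_resps) / num_clients for k in range(num_components)]
```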

Figures (22)

  • Figure 1: Proposed approach to server-side simulations. From left to right: learn Mixture-of-Dirichlet-Multinomials distribution from true federated clients (in this case 2 mixture components); use learned distribution to partition server proxy data into simulated clients; run server-side simulated federated model training using the simulated clients.
  • Figure 2: Normalized mean squared error (MSE) between the ground-truth and inferred parameter values over time. The ground truth corresponds to a medium level of client statistical heterogeneity. Left: $\mathbf{A}$; right: $\boldsymbol{\tau}$.
  • Figure 3: t-SNE visualization of FEMNIST clients; each point corresponds to a single client's class histogram. True clients (green), fully IID simulated clients (blue) and MDM clients (red).
  • Figure 4: t-SNE visualizations of Folktables clients; each point corresponds to a single client's histogram over the race feature (left) and over the binned income feature (right). True clients (green), fully IID simulated clients (blue) and MDM clients (red). Inferred in both cases using $K=7$.
  • Figure 5: FEMNIST test accuracy when training with FedAvg for different settings of local learning rate and local epochs. True clients (dotted green), conditionally IID simulated clients (green), learned MDM simulated clients (red) and fully IID simulated clients (blue).
  • ...and 17 more figures
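Figure 2 reports a normalized MSE between inferred and ground-truth parameters. The caption does not define the normalization, so the sketch below assumes the common choice $\|\hat{\theta} - \theta\|^2 / \|\theta\|^2$ over the flattened parameters. Mixture components are identified only up to relabeling, so the sketch also takes the minimum over permutations of the components. This alignment step is a further assumption on our part, and it is practical only for small K. The function names are hypothetical.

```python
import itertools
import math
from typing import List, Sequence


def normalized_mse(est: Sequence[float], true: Sequence[float]) -> float:
    # Assumed definition: squared error normalized by the squared norm of the truth.
    num = sum((e - t) ** 2 for e, t in zip(est, true))
    den = sum(t ** 2 for t in true)
    return num / den


def aligned_normalized_mse(
    est_rows: Sequence[Sequence[float]], true_rows: Sequence[Sequence[float]]
) -> float:
    # Rows are per-component parameters (e.g. rows of A, or single-entry rows for tau).
    # Mixture components are identifiable only up to relabeling, so try every
    # permutation and keep the best match (feasible only for small K).
    flat_true: List[float] = [x for row in true_rows for x in row]
    best = math.inf
    for perm in itertools.permutations(range(len(est_rows))):
        flat_est = [x for k in perm for x in est_rows[k]]
        best = min(best, normalized_mse(flat_est, flat_true))
    return best
```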

Theorems & Definitions (1)

  • Theorem 3.1