Federated Learning for Non-factorizable Models using Deep Generative Prior Approximations

Conor Hassan, Joshua J Bon, Elizaveta Semenova, Antonietta Mira, Kerrie Mengersen

TL;DR

Current federated learning (FL) methods assume conditional independence across clients. This work removes that restriction by introducing the Structured Independence via deep Generative Model Approximation (SIGMA) prior. SIGMA learns a hierarchical latent representation, with a global latent variable shared across clients and local latent variables for each client, that captures dependence between clients while keeping a structure that is tractable for FL updates. This makes it compatible with established FL algorithms such as Structured Federated Variational Inference. The authors demonstrate SIGMA on synthetic 1D Gaussian process (GP) regression and on a real-world-inspired spatial problem with a conditional autoregressive (CAR) prior over Australia. In both cases SIGMA approximates the target prior and yields posterior inferences close to non-FL benchmarks, and an auxiliary-variable extension improves calibration and reduces overfitting. The approach extends FL to domains where dependence between geographically or otherwise linked clients is essential. Suggested future work includes more expressive generative models and uncertainty-aware variants.

Abstract

Federated learning (FL) allows for collaborative model training across decentralized clients while preserving privacy by avoiding data sharing. However, current FL methods assume conditional independence between client models, limiting the use of priors that capture dependence, such as Gaussian processes (GPs). We introduce the Structured Independence via deep Generative Model Approximation (SIGMA) prior, which enables FL for models that do not factorize across clients, expanding the applicability of FL to fields such as spatial statistics, epidemiology, environmental science, and other domains where modeling dependencies is crucial. The SIGMA prior is a pre-trained deep generative model that approximates the desired prior and induces a specified conditional independence structure in the latent variables, creating an approximate model suitable for FL settings. We demonstrate the SIGMA prior's effectiveness on synthetic data and showcase its utility in a real-world example of FL for spatial data, using a conditional autoregressive prior to model spatial dependence across Australia. Our work enables new FL applications in domains where modeling dependent data is essential for accurate predictions and decision-making.
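
The conditional independence structure described in the abstract can be written out as a rough sketch. The form below is an illustration and is not taken from the paper's equations. It assumes standard latent priors $p(\boldsymbol{z}_G)$ and $p(\boldsymbol{z}_j)$, and it uses hypothetical symbols $\boldsymbol{z}_j$ and $\boldsymbol{\psi}_j$ for the local latent variable and local decoder parameters of client $j$. The symbols $\boldsymbol{z}_G$, $\boldsymbol{h}$, $f_{\boldsymbol{\psi}_G}$, and $\phi$ follow the Figure 2 caption.

$$
p_{\text{SIGMA}}(\hat{\boldsymbol{\theta}} \mid \phi)
= \int p(\boldsymbol{z}_G) \prod_{j=1}^{J} \left[ \int p(\boldsymbol{z}_j)\, p_{\boldsymbol{\psi}_j}\!\left(\hat{\boldsymbol{\theta}}_j \mid \boldsymbol{z}_j, \boldsymbol{h}\right) d\boldsymbol{z}_j \right] d\boldsymbol{z}_G,
\qquad \boldsymbol{h} = f_{\boldsymbol{\psi}_G}(\boldsymbol{z}_G, \phi).
$$

Under this reading, the product over $j$ is what makes the approximate prior usable in FL: once $\boldsymbol{z}_G$ is fixed, each client's factor involves only its own block $\hat{\boldsymbol{\theta}}_j$, while dependence between clients enters only through the shared $\boldsymbol{z}_G$.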

Paper Structure

This paper contains 18 sections, 21 equations, 7 figures, 2 tables.

Figures (7)

  • Figure 1: Left: the dependency structure for a conditional autoregressive (CAR) prior in the standard, non-FL setting. Right: the same grid in an FL setting, where bold borders mark the boundaries between clients and red cells mark values that cannot be evaluated when inferring the value at the dark-green cells.
  • Figure 2: Graphical description of the generative model fit to input prior draws $\boldsymbol{\theta}$ to create the SIGMA prior. The output is $\hat{\boldsymbol{\theta}}=(\hat{\boldsymbol{\theta}}_1^\top, \ldots, \hat{\boldsymbol{\theta}}_J^\top)^\top$, where each of the $J$ blocks of parameters $\hat{\boldsymbol{\theta}}_{j}$ is conditionally independent given the global latent variable $\boldsymbol{z}_G$ and the deterministic output $\boldsymbol{h}$ of a neural network $f_{\boldsymbol{\psi}_G}$, the global decoder, that takes $\boldsymbol{z}_G$ and $\phi$ as input. The SIGMA prior is trained as a hierarchical variational autoencoder with a specific structure, such that the local components of the encoders and decoders factorize over clients.
  • Figure 3: Comparison of the empirical covariance of the draws from the RBF kernel (top row) and the learned SIGMA prior (bottom row) for $\phi=(0.2, 0.5, 0.8)$. The SIGMA prior closely approximates the true covariance structure, capturing client dependencies.
  • Figure 4: Twenty-five draws from the fitted SIGMA prior. The red vertical dashed lines denote boundaries between clients, and the color of each draw denotes the length scale of the Gaussian process kernel that the draw approximates.
  • Figure 5: Mean estimate (solid blue line) and 90% credible interval (shaded region) for the mean function using the SIGMA approximation. The posterior was estimated using the NUTS algorithm. The red dashed lines separate the different clients.
  • ...and 2 more figures
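
The generative structure described in the Figure 2 caption can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: it covers only the decoder side, uses untrained random weights in place of a fitted hierarchical variational autoencoder, and assumes standard normal latent priors. All function names, layer sizes, and the local latent variables `z_locals` are hypothetical choices made for the example.

```python
import math
import random

def random_layer(rng, n_out, n_in):
    """Untrained stand-in for a trained layer: (weight rows, bias)."""
    weights = [[rng.gauss(0.0, 1.0) for _ in range(n_in)] for _ in range(n_out)]
    bias = [rng.gauss(0.0, 1.0) for _ in range(n_out)]
    return weights, bias

def affine(layer, x):
    """Apply weights @ x + bias using plain Python lists."""
    weights, bias = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]

def make_decoders(rng, num_clients, dim_global, dim_local, dim_h, block_size):
    """One global decoder (inputs: z_G and phi) and one local decoder per client."""
    global_dec = random_layer(rng, dim_h, dim_global + 1)  # +1 input for phi
    local_decs = [random_layer(rng, block_size, dim_local + dim_h)
                  for _ in range(num_clients)]
    return global_dec, local_decs

def decode(global_dec, local_decs, z_global, z_locals, phi):
    """Map latents to J parameter blocks; blocks interact only through h."""
    h = [math.tanh(v) for v in affine(global_dec, z_global + [phi])]
    return [affine(dec, z_j + h) for dec, z_j in zip(local_decs, z_locals)]

def sample_prior(rng, global_dec, local_decs, dim_global, dim_local, phi):
    """Draw latents from standard normals (an assumption) and decode them."""
    z_global = [rng.gauss(0.0, 1.0) for _ in range(dim_global)]
    z_locals = [[rng.gauss(0.0, 1.0) for _ in range(dim_local)] for _ in local_decs]
    return decode(global_dec, local_decs, z_global, z_locals, phi)

rng = random.Random(0)
global_dec, local_decs = make_decoders(
    rng, num_clients=3, dim_global=2, dim_local=2, dim_h=4, block_size=5)
theta_hat = sample_prior(rng, global_dec, local_decs, dim_global=2, dim_local=2, phi=0.5)
print([len(block) for block in theta_hat])  # → [5, 5, 5]
```

In this sketch, changing one client's local latent changes only that client's block, while changing the global latent or `phi` changes `h` and therefore affects every block. That is the conditional independence property the caption describes: given $\boldsymbol{z}_G$ and $\boldsymbol{h}$, the blocks $\hat{\boldsymbol{\theta}}_j$ do not depend on one another.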