
Federated Learning with Neural Graphical Models

Urszula Chajewska, Harsh Shrivastava

TL;DR

A federated learning (FL) framework that maintains a global NGM model, which learns the averaged information from the local NGM models while keeping the training data within each client's environment. It also proposes a 'Stitching' algorithm that personalizes the global NGM model by merging additional client-specific variables using the client's data.

Abstract

Federated Learning (FL) addresses the need to create models based on proprietary data in such a way that multiple clients retain exclusive control over their data, while all benefit from improved model accuracy due to pooled resources. Recently proposed Neural Graphical Models (NGMs) are probabilistic graphical models that utilize the expressive power of neural networks to learn complex non-linear dependencies between the input features. They learn to capture the underlying data distribution and have efficient algorithms for inference and sampling. We develop an FL framework that maintains a global NGM model, which learns the averaged information from the local NGM models while keeping the training data within the client's environment. Our design, FedNGMs, avoids the pitfalls and shortcomings of neuron-matching frameworks such as Federated Matched Averaging, which suffers from model parameter explosion. Our global model size remains constant throughout the process. In cases where clients have local variables that are not part of the combined global distribution, we propose a 'Stitching' algorithm, which personalizes the global NGM model by merging the additional variables using the client's data. FedNGM is robust to data heterogeneity, a large number of participants, and limited communication bandwidth. We experimentally demonstrate the use of FedNGMs for extracting insights from the CDC's Infant Mortality dataset and discuss interesting future applications.

Paper Structure

This paper contains 14 sections, 5 equations, 8 figures, and 4 tables.

Figures (8)

  • Figure 1: Graph recovery approaches. A categorization of methods that recover graphs. FedNGM typically uses undirected graph recovery methods such as Neural Graph Revealers (NGR) [shrivastava2023NGR] or the uGLAD algorithm [shrivastava2022uglad, shrivastava2022a]. Methods that recover directed graphs can also be used by adding a post-processing moralization step, which converts a directed graph into its undirected counterpart. The algorithms listed as leaf nodes are representative of each sub-category; the list is not exhaustive. Figure borrowed from [shrivastava2023methods].
  • Figure 2: Obtaining the global consensus graph. To set the stage for training individual client NGMs, we first obtain the sparse graph that captures the dependency structure among the input features for the entire domain. [1] Each client $\mathcal{C}_c$ obtains a dependency graph for its private data $X_c$ (e.g., provided by an expert or recovered with uGLAD), with adjacency matrix $S_c\in\{0,1\}^{|F_c|\times |F_c|}$. [2] All clients send their graphs $S_c$ to the master server. [3] The master server merges the dependency graphs by considering only the features common to all clients, $F_g = \bigcap_{c=1}^C F_c$, and taking the union of all edges among these common features to obtain the global graph $\mathcal{G}_g$. Observe the features $x_1, x_2, x_3$ and how their connections are updated at the master. [4] The master sends the graph $\mathcal{G}_g$ to the clients, and the local and global NGM models are then initialized. Public data can optionally be included in this framework.
  • Figure 3: Training the global NGM$_\mathcal{G}$ model. All clients and the master have a copy of the global consensus graph $\mathcal{G}_g$. [1] Each client trains an NGM model on its data, restricted to the common features, using the objective in Eq. \ref{eqn:optimization-local}. [2] All clients send their trained NGM models to the master; their data remain private. [3] The master trains the global NGM model to learn the average of the client NGMs using the objective in Eq. \ref{eqn:optimization-global}. Optionally, additional public data and samples drawn from the client NGMs can be used when training the global NGM, via the regression term in Eq. \ref{eqn:optimization-global-reg}. [4] The global NGM$_\mathcal{G}$ model is then transferred to all clients, which can run personalized FL to customize their models using their proprietary data.
  • Figure 4: Personalized FL Stitching algorithm. Each client receives the trained global model NGM$_\mathcal{G}$ from the master. The highlighted nodes are the additional nodes introduced in the local model: the orange nodes are the client-specific features, and the dark green nodes are new hidden units that help capture dependencies between the common features and the newly added features. Only the new edges introduced by the additional nodes (the orange arrows, and the thick green arrows representing connections from the new hidden unit to all nodes) are learned from the client's data. The number of new hidden units can be increased if needed.
  • Figure 5: Training the global NGR model. [1] Each client trains an NGR model on its data using the objective in Eq. \ref{eqn:optimization-function-ngr}. [2] All clients send their trained NGR models to the master; their data remain private. [3] The master trains the global NGR model to learn the average of the client NGRs. [4] The global NGR model is then transferred to all clients, which can run personalized FL to customize their models using their proprietary data.
  • ...and 3 more figures
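Figure 1 mentions converting a directed graph into an undirected one by moralization. A minimal sketch in Python with NumPy, assuming the directed graph is given as a 0/1 adjacency matrix where entry (i, j) marks an edge i → j; the function name `moralize` is hypothetical and not part of the paper's code:

```python
import numpy as np

def moralize(dag: np.ndarray) -> np.ndarray:
    """Return the moral (undirected) graph of a DAG as a 0/1 matrix.

    Hypothetical helper: dag[i, j] == 1 means a directed edge i -> j.
    Moralization connects ("marries") every pair of parents that share
    a child, then drops all edge directions.
    """
    dag = np.asarray(dag, dtype=bool)
    n = dag.shape[0]
    moral = dag | dag.T                      # drop edge directions
    for child in range(n):
        parents = np.flatnonzero(dag[:, child])
        for a in parents:                    # marry co-parents of this child
            for b in parents:
                if a != b:
                    moral[a, b] = True
    np.fill_diagonal(moral, False)           # no self-loops
    return moral.astype(int)

# v-structure x0 -> x2 <- x1
v_structure = np.array([[0, 0, 1],
                        [0, 0, 1],
                        [0, 0, 0]])
print(moralize(v_structure))
```

For the v-structure x0 → x2 ← x1, the moral graph keeps both edges to x2 and additionally links the two parents x0 and x1, so the undirected graph does not imply a false independence between them given x2.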
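Step [3] of Figure 2 (intersect the client feature sets, then take the union of edges among the shared features) can be sketched as follows. This assumes each client reports its feature names alongside a symmetric 0/1 adjacency matrix $S_c$; `merge_client_graphs` and the toy client graphs are hypothetical illustrations, not the authors' implementation:

```python
import numpy as np

def merge_client_graphs(client_graphs):
    """Build a global consensus graph from client dependency graphs.

    Hypothetical helper. client_graphs is a list of (feature_names, S_c)
    pairs, where S_c is a symmetric 0/1 adjacency matrix indexed in the
    order of feature_names. Returns (common, S_g): the features shared by
    all clients (F_g) and the union of all edges among those features.
    """
    feature_sets = [set(names) for names, _ in client_graphs]
    # Deterministic ordering: follow the first client's feature order.
    common = [f for f in client_graphs[0][0]
              if all(f in s for s in feature_sets)]
    k = len(common)
    S_g = np.zeros((k, k), dtype=int)
    for names, S_c in client_graphs:
        idx = [names.index(f) for f in common]      # positions of common features
        sub = np.asarray(S_c)[np.ix_(idx, idx)]     # restrict to common features
        S_g |= (sub != 0).astype(int)               # union of edges
    return common, S_g

# Toy clients: A has edges x1-x2 and x3-x4; B has edges x2-x3 and x1-x5.
client_a = (["x1", "x2", "x3", "x4"],
            [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])
client_b = (["x3", "x1", "x2", "x5"],
            [[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]])
common, S_g = merge_client_graphs([client_a, client_b])
print(common)
print(S_g)
```

Edges that touch a client-only feature (x3-x4 and x1-x5 here) are dropped, while every edge among x1, x2, x3 reported by any client survives in the global graph.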
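The objective in Eq. \ref{eqn:optimization-global} is not reproduced here, so the following only sketches the idea in step [3] of Figure 3 under a stated assumption: the master fits a global network to the average of the client networks' outputs on a set of input samples (for example public data or samples drawn from client NGMs), using a mean-squared-error loss and plain gradient descent. The NGM graph-structure constraint is omitted, every model is a one-hidden-layer ReLU MLP, and all names (`init_mlp`, `forward`, `distill_global`) are hypothetical:

```python
import numpy as np

def init_mlp(d, h, rng):
    """One-hidden-layer MLP R^d -> R^d (assumption: stands in for an NGM)."""
    return {"W1": rng.normal(scale=0.5, size=(d, h)), "b1": np.zeros(h),
            "W2": rng.normal(scale=0.5, size=(h, d)), "b2": np.zeros(d)}

def forward(p, X):
    """Return the MLP output and the hidden activations."""
    H = np.maximum(X @ p["W1"] + p["b1"], 0.0)   # ReLU hidden layer
    return H @ p["W2"] + p["b2"], H

def distill_global(client_models, X, hidden=16, lr=0.05, steps=500, seed=1):
    """Fit a global MLP to the average client output on inputs X (MSE loss)."""
    target = np.mean([forward(p, X)[0] for p in client_models], axis=0)
    g = init_mlp(X.shape[1], hidden, np.random.default_rng(seed))
    losses = []
    for _ in range(steps):
        Y, H = forward(g, X)
        R = Y - target                            # residuals w.r.t. averaged clients
        losses.append(float(np.mean(R ** 2)))
        dY = 2.0 * R / R.size                     # gradient of the mean squared error
        dW2, db2 = H.T @ dY, dY.sum(axis=0)
        dH = (dY @ g["W2"].T) * (H > 0)           # backprop through ReLU
        dW1, db1 = X.T @ dH, dH.sum(axis=0)
        for k, grad in (("W1", dW1), ("b1", db1), ("W2", dW2), ("b2", db2)):
            g[k] -= lr * grad                     # all grads computed before updating
    return g, losses

rng = np.random.default_rng(0)
d = 4
clients = [init_mlp(d, 8, rng) for _ in range(3)]   # stand-ins for trained client NGMs
X = rng.normal(size=(200, d))                       # stand-in for public data / samples
g, losses = distill_global(clients, X)
print(f"MSE to averaged clients: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

In this sketch the global network's parameter shapes depend only on the number of features and the chosen hidden size, not on the number of clients, which matches the abstract's point that the global model size stays constant.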
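The key constraint in the Stitching step of Figure 4 is that only edges touching the new nodes are learned, while the weights inherited from NGM$_\mathcal{G}$ stay fixed. One way to express this, assumed here rather than taken from the paper, is to grow each weight matrix and apply a 0/1 gradient mask. The helpers `stitch_layer` and `masked_step` are hypothetical, biases are omitted, and the random matrices stand in for trained global weights:

```python
import numpy as np

def stitch_layer(W_old, n_new_in, n_new_out, rng, scale=0.1):
    """Grow a trained weight matrix; return (W_new, trainable_mask).

    Hypothetical helper. W_old has shape (n_in, n_out). New rows (extra
    inputs) and new columns (extra outputs) are randomly initialized; the
    mask is 1 only on those new entries, so the original block stays frozen.
    """
    n_in, n_out = W_old.shape
    W = rng.normal(scale=scale, size=(n_in + n_new_in, n_out + n_new_out))
    W[:n_in, :n_out] = W_old          # copy the global model's weights
    mask = np.ones_like(W)
    mask[:n_in, :n_out] = 0.0         # freeze the inherited block
    return W, mask

def masked_step(W, grad, mask, lr):
    """Gradient step that only updates entries where mask == 1."""
    return W - lr * grad * mask

rng = np.random.default_rng(0)
W1_g = rng.normal(size=(3, 4))   # global: 3 common features -> 4 hidden units
W2_g = rng.normal(size=(4, 3))   # global: 4 hidden units -> 3 common features
# Client adds 2 local features and 1 new hidden unit.
W1, M1 = stitch_layer(W1_g, n_new_in=2, n_new_out=1, rng=rng)
W2, M2 = stitch_layer(W2_g, n_new_in=1, n_new_out=2, rng=rng)
grad = np.ones_like(W1)          # placeholder gradient for illustration
W1_next = masked_step(W1, grad, M1, lr=0.1)
print(M1)
```

The mask makes the "learn only the new edges" rule explicit: in `W1`, the rows for the new features and the column for the new hidden unit are trainable, and in `W2`, the row for the new hidden unit and the columns for the new output features are trainable, while the 3×4 and 4×3 blocks copied from the global model are left untouched.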