Variational Federated Multi-Task Learning

Luca Corinzia, Ami Beuret, Joachim M. Buhmann

TL;DR

This work tackles federated learning under strong data heterogeneity by reframing the problem as federated multi-task learning for non-convex models. The authors propose VIRTUAL, a variational inference framework that models the federation as a star-shaped Bayesian network connecting a central server to multiple clients, yielding shared (server) and task-specific (client) parameters with privacy-preserving updates. Empirical results on diverse real-world datasets show that VIRTUAL outperforms FedAvg and FedProx in multi-task (MT) performance while enabling sparser updates and reduced communication. The approach offers a principled way to transfer knowledge across clients and personalize models to local distributions, with practical implications for scalable, privacy-preserving federated learning.

Abstract

In federated learning, a central server coordinates the training of a single model on a massively distributed network of devices. This setting extends naturally to a multi-task learning framework that can handle real-world federated datasets, which typically show strong statistical heterogeneity among devices. Although federated multi-task learning has been shown to be an effective paradigm for real-world datasets, it has so far been applied only to convex models. In this work, we introduce VIRTUAL, an algorithm for federated multi-task learning with general non-convex models. In VIRTUAL, the federated network of the server and the clients is treated as a star-shaped Bayesian network, and learning is performed on the network using approximate variational inference. We show that this method is effective on real-world federated datasets, outperforming the current state of the art for federated learning while also allowing sparser gradient updates.
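To make the per-client variational objective more concrete, here is a minimal sketch of a $\beta$-weighted free energy for a single client. It rests on several assumptions. The posteriors over the shared server parameters $\bm{\theta}$ and the client parameters $\bm{\phi}_i$ are diagonal (mean-field) Gaussians. The expected log-likelihood is estimated by Monte Carlo with reparameterized draws $w = \mu + \sigma\epsilon$. The model is a toy one, $y \sim \mathcal{N}(\theta x + \phi, 1)$. All function names, the toy likelihood, and the choice of reference distributions are hypothetical and are not taken from the paper. The argument `beta` plays the role of the KL multiplier $\beta$ discussed with Figure 2.

```python
import math
import random


def kl_diag_gauss(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(q || p) between diagonal Gaussians, summed over dimensions."""
    return sum(
        math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
        for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p)
    )


def sample(mu, sigma, rng):
    """Reparameterized draw w = mu + sigma * eps with eps ~ N(0, 1)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


def client_free_energy(data, server_q, client_q, server_ref, client_prior,
                       beta=1.0, n_samples=8, seed=0):
    """Monte Carlo estimate of a beta-weighted free energy for one client.

    Hypothetical toy model: y ~ N(theta[0] * x + phi[0], 1), where theta is the
    shared (server) parameter and phi the client-specific one. Every *_q,
    *_ref and *_prior argument is a (mu_list, sigma_list) pair.
    """
    rng = random.Random(seed)
    nll = 0.0
    for _ in range(n_samples):
        theta = sample(*server_q, rng)
        phi = sample(*client_q, rng)
        for x, y in data:
            pred = theta[0] * x + phi[0]
            # Gaussian negative log-likelihood with unit noise variance.
            nll += 0.5 * (y - pred) ** 2 + 0.5 * math.log(2 * math.pi)
    nll /= n_samples
    # KL terms pull both posteriors toward their reference distributions.
    kl = kl_diag_gauss(*server_q, *server_ref) + kl_diag_gauss(*client_q, *client_prior)
    return nll + beta * kl


data = [(0.0, 1.0), (1.0, 3.1), (2.0, 4.9)]
ref = ([0.0], [1.0])
print(client_free_energy(data, ([2.0], [0.1]), ([1.0], [0.1]), ref, ref, beta=0.1))
```

Raising `beta` strengthens the pull of both posteriors toward their references. Figure 2 examines this regularization effect. The sketch uses a fixed standard-normal reference for the server term. In an actual federated round, that reference would presumably be derived from the current server posterior (an assumption made only for illustration).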

Paper Structure

This paper contains 20 sections, 1 theorem, 7 equations, 6 figures, 4 tables, 1 algorithm.

Key Result

Proposition 1

Assuming that factor $i$ is refined at step $t$, the proxy pdfs $s_i^{(t)}(\bm{\theta})$ and $c^{(t)}_i(\bm{\phi}_i)$ are found by minimizing the variational free energy $\mathcal{L}_i \coloneqq \mathcal{L}(s_i(\bm{\theta}),c_i(\bm{\phi}_i))$, where $s^{(t)}(\bm{\theta}) = s_{i}(\bm{\theta})\prod_{j \ne i}^K s_j^{(t-1)}(\bm{\theta})$ is the updated posterior over the server parameters...
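Proposition 1 writes the updated server posterior as a product of per-client factors, with factor $i$ replaced and the others held at step $t-1$. The following minimal sketch assumes each factor is a diagonal Gaussian, in line with the Gaussian server parameters $(\bm{\mu}^s,\bm{\sigma}^s)$ in Figure 1. Under that assumption, multiplying factors amounts to adding their natural parameters (precision and precision-weighted mean). The class and function names are hypothetical, normalization constants are ignored, and at least two clients are assumed.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class GaussFactor:
    """Diagonal Gaussian factor stored in natural parameters (one entry per weight)."""
    eta: List[float]  # precision-weighted mean, mu / sigma^2
    lam: List[float]  # precision, 1 / sigma^2

    @classmethod
    def from_moments(cls, mu, sigma):
        return cls([m / s ** 2 for m, s in zip(mu, sigma)],
                   [1.0 / s ** 2 for s in sigma])

    def to_moments(self):
        mu = [e / l for e, l in zip(self.eta, self.lam)]
        sigma = [l ** -0.5 for l in self.lam]
        return mu, sigma


def product(factors):
    """Multiplying Gaussian densities (up to normalization) adds natural parameters."""
    dim = len(factors[0].eta)
    return GaussFactor([sum(f.eta[d] for f in factors) for d in range(dim)],
                       [sum(f.lam[d] for f in factors) for d in range(dim)])


def cavity(factors, i):
    """prod_{j != i} s_j^{(t-1)}: the server posterior with client i's factor removed."""
    return product([f for j, f in enumerate(factors) if j != i])


def refine(factors, i, new_factor):
    """Swap in factor i; return (updated factors, s^{(t)} = s_i * prod_{j != i} s_j^{(t-1)})."""
    posterior = product([new_factor, cavity(factors, i)])
    updated = list(factors)
    updated[i] = new_factor
    return updated, posterior


f1 = GaussFactor.from_moments([0.0], [1.0])
f2 = GaussFactor.from_moments([2.0], [1.0])
print(product([f1, f2]).to_moments())  # mean 1.0, std 1/sqrt(2)
```

Recomputing the full product costs $O(Kd)$ per update. An equivalent alternative is to keep $s^{(t-1)}$ and subtract client $i$'s old natural parameters to obtain the cavity. In expectation-propagation-style updates, the cavity precision must stay positive for the cavity to remain a proper Gaussian.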

Figures (6)

  • Figure 1: Graphical models that describe the VIRTUAL framework for federated learning. The plates represent replicates. In both figures, the outer plate replicates client $i$ over the total number of clients $K$, while the inner plate replicates sample index $n$ over the total number of samples per client $N_i$. Shaded nodes represent observed variables and unshaded nodes represent latent variables. (a) Solid lines denote the discriminative model $p(y_i^{(n)}| \bm{x}_i^{(n)}, \bm{\theta}, \bm{\phi}_i)$. (b) Graphical model of the approximate variational posterior. Dashed lines denote (deterministic) dependencies in the approximate variational posterior, while dotted lines denote stochastic dependencies. Here, $(\bm{\mu}^s,\bm{\sigma}^s)$ and $(\bm{\mu}^c_i,\bm{\sigma}^c_i)$ denote the collections of all Gaussian parameters of the server and of client $i$, respectively.
  • Figure 2: Regularization effect of the KL divergence on the cross-entropy loss. For the FEMNIST dataset, we report the server and MT cross-entropy losses during training for different values of the KL divergence multiplier $\beta$. Thick lines show moving averages with a window size of 20. A log scale is used for both the y-axis and the color bar.
  • Figure 3: Central server (first row) and multi-task (MT, second row) cross-entropy loss as a function of the federated training round. For every dataset (column), we report the loss of the two baselines, FedAvg and FedProx, and of our method, VIRTUAL. For the FEMNIST dataset, we report the performance of both the MLP and the convolutional NN architectures. For all other datasets, only one model is shown (an LSTM for the Shakespeare dataset, MLPs for all the others). The y-axis uses a log scale.
  • Figure 4: Cumulative distribution function of the signal-to-noise ratio (in log scale) for all clients and for each of the three consecutive dense layers of the network. Clients that are not initialized with the server weights at each round yield a more compressible model. Simulation performed on the FEMNIST dataset with an MLP architecture.
  • Figure 5: Additional learning curves.
  • ...and 1 more figure
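Figure 4 links the signal-to-noise ratio (SNR) of the weights to how compressible a client model is. This summary does not give the exact SNR definition or the compression rule. A common choice for Gaussian mean-field weights is $|\mu|/\sigma$. The minimal sketch below assumes that definition and a hypothetical threshold rule, in which only weights with high SNR are transmitted. It shows how an SNR CDF like the one plotted could be computed, and how it could translate into a sparser update.

```python
import bisect


def snr(mu, sigma):
    """Per-weight signal-to-noise ratio |mu| / sigma (assumed definition)."""
    return [abs(m) / s for m, s in zip(mu, sigma)]


def empirical_cdf(values, x):
    """Fraction of values <= x; Figure 4 plots this curve with x on a log scale."""
    ordered = sorted(values)
    return bisect.bisect_right(ordered, x) / len(ordered)


def sparsify_update(mu, sigma, threshold):
    """Hypothetical rule: send only the means whose SNR exceeds `threshold`.

    Returns a sparse update {weight index: mean}; omitted weights are not transmitted.
    """
    return {k: m for k, (m, r) in enumerate(zip(mu, snr(mu, sigma))) if r > threshold}


mu = [0.0, 0.05, 1.0, -2.0]
sigma = [0.1, 0.1, 0.1, 0.5]
print(empirical_cdf(snr(mu, sigma), 1.0))           # 0.5
print(sparsify_update(mu, sigma, threshold=1.0))    # {2: 1.0, 3: -2.0}
```

A CDF that rises steeply at low SNR means that many weights carry little signal relative to their posterior uncertainty. Dropping such weights from the upload is one plausible way to obtain the sparser, cheaper updates mentioned in the abstract.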

Theorems & Definitions (2)

  • Proposition 1
  • Proof of Proposition 1