Variational Federated Multi-Task Learning
Luca Corinzia, Ami Beuret, Joachim M. Buhmann
TL;DR
This work tackles federated learning under strong data heterogeneity by reframing the problem as federated multi-task learning for non-convex models. The authors propose VIRTUAL, a framework based on variational inference that models the central server and its clients as a star-shaped Bayesian network. This yields shared parameters and task-specific parameters, with privacy-preserving updates. Empirical results on diverse real-world datasets show that VIRTUAL outperforms FedAvg and FedProx in multi-task performance while enabling sparser updates and reduced communication. The approach offers a principled way to transfer knowledge across clients and to personalize models to local data distributions, with practical implications for scalable, privacy-preserving federated learning.
Abstract
In federated learning, a central server coordinates the training of a single model on a massively distributed network of devices. This setting extends naturally to a multi-task learning framework, which can handle real-world federated datasets that typically show strong statistical heterogeneity among devices. Although federated multi-task learning has proven to be an effective paradigm for real-world datasets, it has so far been applied only to convex models. In this work, we introduce VIRTUAL, an algorithm for federated multi-task learning with general non-convex models. In VIRTUAL, the federated network of the server and the clients is treated as a star-shaped Bayesian network, and learning is performed on this network using approximate variational inference. We show that this method is effective on real-world federated datasets, outperforming current state-of-the-art federated learning methods while also allowing sparser gradient updates.
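The following is a minimal, hypothetical Python sketch of the star-shaped structure. It is not the VIRTUAL algorithm itself. It assumes a toy model y = θx + φ_k + noise, in which the server holds one shared slope θ as a Gaussian and each client k keeps one private intercept φ_k. It also assumes that the server's distribution factorizes as a prior times one Gaussian factor per client, as in expectation-propagation-style schemes. On each turn, a client removes its own old factor (the "cavity"), refits its private intercept locally, and returns a new factor for θ. Because everything here is Gaussian and conjugate, each local update is exact rather than variational. The names `NatGauss`, `client_step`, `product` and `run` are illustrative only.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class NatGauss:
    """1-D Gaussian in natural parameters: precision and precision * mean."""
    prec: float
    prec_mean: float

    @property
    def mean(self) -> float:
        return self.prec_mean / self.prec

    def __mul__(self, other: "NatGauss") -> "NatGauss":
        # Product of Gaussian densities: natural parameters add.
        return NatGauss(self.prec + other.prec, self.prec_mean + other.prec_mean)

    def __truediv__(self, other: "NatGauss") -> "NatGauss":
        # Ratio of Gaussian densities: natural parameters subtract.
        return NatGauss(self.prec - other.prec, self.prec_mean - other.prec_mean)


def client_step(cavity: NatGauss, xs: List[float], ys: List[float],
                noise_var: float) -> Tuple[float, NatGauss]:
    """Hypothetical client update for the toy model y = theta * x + phi + noise.

    theta is shared (held by the server); phi is private and never sent.
    Returns the refitted phi and the new factor for theta sent to the server.
    """
    # Refit the private intercept with theta fixed at the cavity mean.
    m = cavity.mean
    phi = sum(y - m * x for x, y in zip(xs, ys)) / len(xs)
    # Gaussian likelihood of theta given residuals y - phi; by conjugacy
    # the local posterior is exactly cavity * factor.
    factor = NatGauss(
        sum(x * x for x in xs) / noise_var,
        sum(x * (y - phi) for x, y in zip(xs, ys)) / noise_var,
    )
    return phi, factor


def product(prior: NatGauss, factors: List[NatGauss]) -> NatGauss:
    """Multiply the prior by every client factor."""
    q = prior
    for t in factors:
        q = q * t
    return q


def run(clients, rounds: int = 20, prior_var: float = 100.0,
        noise_var: float = 1.0) -> Tuple[NatGauss, List[float]]:
    """Server loop: q(theta) = prior * product of per-client factors (assumed)."""
    prior = NatGauss(1.0 / prior_var, 0.0)            # p(theta) = N(0, prior_var)
    factors = [NatGauss(0.0, 0.0) for _ in clients]   # flat initial factors
    phis = [0.0] * len(clients)                       # private, one per client
    for _ in range(rounds):
        for k, (xs, ys) in enumerate(clients):
            # Remove client k's old factor before it sees the shared belief.
            cavity = product(prior, factors) / factors[k]
            phis[k], factors[k] = client_step(cavity, xs, ys, noise_var)
    return product(prior, factors), phis


if __name__ == "__main__":
    # Two clients generated from y = 2x + phi_k with phi_0 = 1 and phi_1 = -3.
    data = [([1.0, 2.0, 3.0, 4.0], [3.0, 5.0, 7.0, 9.0]),
            ([0.0, 1.0, 2.0], [-3.0, -1.0, 1.0])]
    q, phis = run(data)
    print(f"shared theta ~ N({q.mean:.3f}, {1 / q.prec:.4f}); private phis = {phis}")
```

With a weak prior (variance 100), the shared slope lands close to the generating value of 2 but not exactly on it, because of prior shrinkage. Each intercept is only ever stored on its own client. Refitting φ_k at the cavity mean is a coordinate-wise simplification rather than joint inference. VIRTUAL itself targets non-convex models, where no closed form exists and approximate variational inference is needed.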
