A Multi-Token Coordinate Descent Method for Semi-Decentralized Vertical Federated Learning

Pedro Valdeira, Yuejie Chi, Cláudia Soares, João Xavier

TL;DR

This work tackles semi-decentralized vertical federated learning by introducing MTCD, a multi-token coordinate descent method that unifies client-server and decentralized schemes and enables a spectrum between them. The method extends STCD to multiple tokens that roam and periodically sync at a server, balancing parallel computation and information coupling to improve convergence and communication efficiency. Theoretical results establish $O(1/T)$ convergence for nonconvex objectives (and related rates under convexity) under large batch sizes, with analysis covering token overlap, clustering, and mini-batch variance. Empirically, MTCD outperforms fully decentralized and standard client-server baselines across convex problems and neural-network tasks, particularly when communication costs differ between client-client and client-server links, highlighting its practical impact for scalable VFL deployments.

Abstract

Most federated learning (FL) methods use a client-server scheme, where clients communicate only with a central server. However, this scheme is prone to bandwidth bottlenecks at the server and has a single point of failure. In contrast, in a (fully) decentralized approach, clients communicate directly with each other, dispensing with the server and mitigating these issues. Yet, as the client network grows larger and sparser, the convergence of decentralized methods slows down, even failing to converge if the network is disconnected. This work addresses this gap between client-server and decentralized schemes, focusing on the vertical FL setup, where clients hold different features of the same samples. We propose multi-token coordinate descent (MTCD), a flexible semi-decentralized method for vertical FL that can exploit both client-server and client-client links. By selecting appropriate hyperparameters, MTCD recovers the client-server and decentralized schemes as special cases. In fact, its decentralized instance is itself a novel method of independent interest. Yet, by controlling the degree of dependency on client-server links, MTCD can also explore a spectrum of schemes ranging from client-server to decentralized. We prove that, for sufficiently large batch sizes, MTCD converges at an $\mathcal{O}(1/T)$ rate for nonconvex objectives when the tokens roam across disjoint subsets of clients. To capture the aforementioned drawbacks of the client-server scheme succinctly, we model the relative impact of using client-server versus client-client links as the ratio of their "costs", which depends on the application. This allows us to demonstrate, both analytically and empirically, that by tuning the degree of dependency on the server, the semi-decentralized instances of MTCD can outperform both client-server and decentralized approaches across a range of applications.
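
The cost-ratio idea in the abstract can be illustrated with a small sketch. It assumes the total communication cost is measured in units of one client-client message and that each client-server message costs `cost_ratio` times as much. The function name and the message counts below are hypothetical and chosen only for illustration; they are not results from the paper.

```python
def total_communication_cost(client_client_msgs, client_server_msgs, cost_ratio):
    """Total cost in units of one client-client message.

    cost_ratio is the assumed cost of one client-server message relative to
    one client-client message (an illustrative stand-in for the paper's ratio).
    """
    if min(client_client_msgs, client_server_msgs) < 0 or cost_ratio < 0:
        raise ValueError("message counts and cost ratio must be non-negative")
    return client_client_msgs + cost_ratio * client_server_msgs


if __name__ == "__main__":
    # Hypothetical message counts for two runs reaching the same accuracy.
    decentralized = {"client_client_msgs": 1000, "client_server_msgs": 0}
    semi_decentralized = {"client_client_msgs": 400, "client_server_msgs": 50}

    for ratio in (5, 10, 20):
        cost_dec = total_communication_cost(cost_ratio=ratio, **decentralized)
        cost_semi = total_communication_cost(cost_ratio=ratio, **semi_decentralized)
        print(f"ratio={ratio}: decentralized={cost_dec}, semi-decentralized={cost_semi}")
```

With these made-up counts, the run that uses the server is cheaper at ratios 5 and 10 but more expensive at ratio 20, which is the kind of application-dependent trade-off the ratio is meant to capture.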

Paper Structure (42 sections, 4 theorems, 50 equations, 9 figures, 2 algorithms)


Key Result

Lemma 1

Let $\{\bm{\theta}^t\}$ be a sequence generated by the MTCD algorithm. If $f$ is $L$-smooth and our mini-batch gradient estimate is unbiased, we have that:

Figures (9)

  • Figure 1: On the left, we illustrate the semi-decentralized setup, where client-server communications are represented by dashed blue lines and client-client communications by solid green lines. On the right, we present a split neural network, where $K$ representations are obtained from neural networks, before an aggregation mechanism $H$ is applied and its result is fed into a fusion neural network.
  • Figure 2: Ridge regression with $K=40$ clients, for six different network topologies. The MTCD run has $S\to\infty$ and $\Gamma=1$ (i.e., it corresponds to STCD). Algebraic connectivity, $\alpha_{\mathcal{G}}$, is the second smallest eigenvalue of the Laplacian matrix of a graph $\mathcal{G}$.
  • Figure 3: Sparse logistic regression on an Erdős–Rényi graph and a path graph, both with $K=40$ clients. The MTCD run with $S\to\infty$ has $\Gamma=1$ (i.e., it corresponds to STCD). Note that the flat DCPA trajectories with respect to communication cost have not plateaued; they simply take longer to see a drop in suboptimality.
  • Figure 4: We perform ridge regression on a path graph with $K=80$ nodes. We plot the suboptimality per iteration, the suboptimality per communication, and the number of communications needed to reach a given suboptimality for a range of ratios $R_C$. In the top row, MTCD with $S\to\infty$ has $\Gamma=1$ (i.e., it corresponds to STCD) and MTCD with $S=64$ has $\Gamma=2$ tokens. In the bottom row, all instances of MTCD have $\Gamma=2$ tokens.
  • Figure 5: Ridge regression on a path graph with $K=40$ clients, for client-server communication costs $20\times$, $10\times$, and $5\times$ larger than the client-client communication cost. (Single local update, $Q=1$.)
  • ...and 4 more figures
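
The algebraic connectivity defined in the Figure 2 caption (the second smallest eigenvalue of the graph Laplacian) can be computed as in the following minimal sketch. It assumes an undirected, unweighted graph given as a symmetric 0/1 adjacency matrix; the helper names `path_graph` and `complete_graph` are illustrative and do not come from the paper.

```python
import numpy as np


def laplacian(adjacency):
    """Graph Laplacian L = D - A for a symmetric adjacency matrix A."""
    adjacency = np.asarray(adjacency, dtype=float)
    return np.diag(adjacency.sum(axis=1)) - adjacency


def algebraic_connectivity(adjacency):
    """Second smallest eigenvalue of the Laplacian."""
    # eigvalsh handles symmetric matrices and returns eigenvalues in ascending order.
    eigenvalues = np.linalg.eigvalsh(laplacian(adjacency))
    return float(eigenvalues[1])


def path_graph(k):
    """Adjacency matrix of a path on k nodes (node i linked to node i + 1)."""
    adjacency = np.zeros((k, k))
    idx = np.arange(k - 1)
    adjacency[idx, idx + 1] = 1.0
    adjacency[idx + 1, idx] = 1.0
    return adjacency


def complete_graph(k):
    """Adjacency matrix of the complete graph on k nodes."""
    return np.ones((k, k)) - np.eye(k)


if __name__ == "__main__":
    k = 40
    print("path graph:    ", algebraic_connectivity(path_graph(k)))
    print("complete graph:", algebraic_connectivity(complete_graph(k)))
```

A path graph has a much smaller algebraic connectivity than a complete graph on the same number of nodes, reflecting how sparse topologies couple the clients more weakly.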

Theorems & Definitions (4)

  • Lemma 1: Inner product upper bound
  • Lemma 2: Squared norm upper bound
  • Theorem 1: Main result (token-per-cluster setting)
  • Theorem 2: Overlapping token trajectories setting