
Federated Automatic Differentiation

Keith Rush, Zachary Charles, Zachary Garrett

TL;DR

This work shows how federated automatic differentiation (FAD) can be used to create algorithms that dynamically learn components of the algorithm itself. In particular, FedAvg-style algorithms can exhibit significantly improved performance when FAD is used to adjust the server optimization step automatically, or to learn weighting schemes for computing weighted averages across clients.

Abstract

Federated learning (FL) is a general framework for learning across an axis of group-partitioned data (heterogeneous clients) while preserving data privacy, under the orchestration of a central server. FL methods often compute gradients of loss functions purely locally (i.e., entirely at each client, or entirely at the server), typically using automatic differentiation (AD) techniques. We propose a federated automatic differentiation (FAD) framework that 1) enables computing derivatives of functions involving client and server computation as well as communication between them, and 2) operates in a manner compatible with existing federated technology. In other words, FAD computes derivatives across communication boundaries. We show, in analogy with traditional AD, that FAD may be implemented using various accumulation modes, which introduce distinct computation-communication trade-offs and systems requirements. Further, we show that a broad class of federated computations is closed under these various modes of FAD, implying in particular that if the original computation can be implemented using privacy-preserving primitives, its derivative may be computed using only these same primitives. We then show how FAD can be used to create algorithms that dynamically learn components of the algorithm itself. In particular, we show that FedAvg-style algorithms can exhibit significantly improved performance by using FAD to adjust the server optimization step automatically, or by using FAD to learn weighting schemes for computing weighted averages across clients.
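To make the closure property concrete, here is a minimal NumPy sketch, not the authors' implementation. It assumes a toy computation y = sum_i g(x, z_i) with a hypothetical client function g(x, z) = sum(sin(x * z)), where each z_i stays on its client. The helpers `federated_broadcast` and `federated_sum` are plain-Python stand-ins named after the primitives in the figure captions, not TensorFlow Federated calls. Forward mode broadcasts a tangent alongside x; reverse mode broadcasts the output cotangent (the adjoint of a sum is a broadcast) and sums client gradients (the adjoint of a broadcast is a sum). Both derivatives therefore use only the two primitives the original computation already uses.

```python
import numpy as np

# Plain-Python stand-ins for the federated primitives (hypothetical helpers).
def federated_broadcast(value, num_clients):
    # Server -> clients: every client receives its own copy of the server value.
    return [np.array(value, copy=True) for _ in range(num_clients)]

def federated_sum(client_values):
    # Clients -> server: the server only ever sees the aggregate.
    return np.sum(client_values, axis=0)

# Hypothetical client function g(x, z) = sum(sin(x * z)) and its local derivatives.
def client_fn(x, z):
    return np.sum(np.sin(x * z))

def client_jvp(x, z, dx):
    # Directional derivative of g at x along the tangent dx.
    return np.sum(np.cos(x * z) * z * dx)

def client_vjp(x, z, ct):
    # Cotangent ct (a scalar) pulled back to a gradient with respect to x.
    return ct * np.cos(x * z) * z

def forward(x, client_data):
    xs = federated_broadcast(x, len(client_data))
    return federated_sum([client_fn(xi, z) for xi, z in zip(xs, client_data)])

def forward_mode(x, dx, client_data):
    # Broadcast primal and tangent; clients compute JVPs; the server sums them.
    n = len(client_data)
    xs, dxs = federated_broadcast(x, n), federated_broadcast(dx, n)
    return federated_sum([client_jvp(xi, z, dxi)
                          for xi, dxi, z in zip(xs, dxs, client_data)])

def reverse_mode(x, ct, client_data):
    # Adjoint of federated_sum is a broadcast of the cotangent;
    # adjoint of federated_broadcast is a sum of the client gradients.
    n = len(client_data)
    xs, cts = federated_broadcast(x, n), federated_broadcast(ct, n)
    return federated_sum([client_vjp(xi, z, cti)
                          for xi, z, cti in zip(xs, client_data, cts)])

# Usage: four clients, each holding private data of dimension 3.
rng = np.random.default_rng(0)
client_data = [rng.normal(size=3) for _ in range(4)]
x, dx = rng.normal(size=3), rng.normal(size=3)
grad = reverse_mode(x, 1.0, client_data)   # full gradient dy/dx at the server
jvp = forward_mode(x, dx, client_data)     # directional derivative along dx
eps = 1e-6
fd = (forward(x + eps * dx, client_data) - forward(x - eps * dx, client_data)) / (2 * eps)
```

The two modes show the trade-off the abstract mentions: forward mode sends one extra broadcast per tangent direction, while reverse mode yields the full gradient with one cotangent broadcast and one sum, at the cost of a prior forward pass.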

Paper Structure (30 sections, 16 equations, 13 figures, 4 tables, 2 algorithms)

Figures (13)

  • Figure 1: Computational graph of $f(x_1, x_2) = \sin(x_1) + x_1 x_2$. Here, each $v_i$ is an intermediate value of the computation, as given in the accompanying equation for this function.
  • Figure 2: Federated computational graph for $y = f(x)$, as in the basic federated computation example, with three participating clients. Here, each $v_i$ is an intermediate value used in the computation. Each value has a placement (server or clients), and the clients hold data $z_i$ that are not shared with one another or with the server. Server$\to$client communication is done via federated_broadcast, and client$\to$server communication is done via federated_sum.
  • Figure 3: Federated computational graph for a single round of FedOpt.
  • Figure 4: Federated computational graph used to compute server hypergradients in FedOpt. All server$\to$client communication is done via federated_broadcast, while all client$\to$server communication is done via federated_mean. This computation produces an updated model and an estimate of the loss of that model.
  • Figure 5: Federated computational graph used to compute hypergradients of server hyperparameters in FedOpt. All server$\to$client communication is done via federated_broadcast, while all client$\to$server communication is done via federated_mean. This computation produces an updated model and an estimate of the loss of that model. In contrast to Figure 4, model training and loss computation (for the purposes of computing hypergradients) are done in parallel, potentially across different sets of clients. Applying federated AD to compute the derivative of the average loss with respect to the hyperparameter requires chaining two of these graphs together (to create a path from "hparam" to "average loss"); the server can perform this chaining after the fact, using only a direct application of federated AD to the graph shown here.
  • ...and 8 more figures
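A toy version of the serial hypergradient computation described for Figure 4 might look like the following NumPy sketch. Everything concrete here is a hypothetical assumption, not taken from the paper: quadratic client losses 0.5 * ||w - c_i||^2 with private targets c_i, five local SGD steps with client learning rate 0.1, a server SGD step on the averaged model delta, and plain hypergradient descent on the server learning rate with a meta learning rate of 0.05. Reverse-mode FAD through the loss computation amounts to broadcasting a unit cotangent and taking a federated mean of client gradients; the remaining chain-rule factor, d(new_x)/d(server_lr) = -delta, is computed locally at the server.

```python
import numpy as np

def client_loss(w, c):
    # Hypothetical client objective with private target c.
    return 0.5 * np.sum((w - c) ** 2)

def client_grad(w, c):
    return w - c

def local_train(x, c, client_lr=0.1, steps=5):
    # Local SGD from the broadcast server model x; returns the model delta.
    w = x.copy()
    for _ in range(steps):
        w = w - client_lr * client_grad(w, c)
    return x - w

def fedopt_round_with_hypergrad(x, server_lr, clients):
    # Training: broadcast x, clients train locally, server averages deltas.
    delta = np.mean([local_train(x, c) for c in clients], axis=0)
    new_x = x - server_lr * delta  # server SGD step on the averaged delta
    # Loss estimate: broadcast new_x, clients report losses, server averages.
    loss = np.mean([client_loss(new_x, c) for c in clients])
    # Reverse-mode FAD through the loss computation: broadcast a cotangent of
    # 1.0, clients compute local gradients, server takes their mean.
    grad_new_x = np.mean([client_grad(new_x, c) for c in clients], axis=0)
    # Server-local chain rule: d(new_x)/d(server_lr) = -delta.
    hypergrad = grad_new_x @ (-delta)
    return new_x, loss, hypergrad

# Usage: four clients with private 3-dimensional targets.
rng = np.random.default_rng(0)
clients = [rng.normal(size=3) + 2.0 for _ in range(4)]
x = np.zeros(3)
server_lr, meta_lr = 1.0, 0.05
losses = []
for _ in range(10):
    x, loss, hg = fedopt_round_with_hypergrad(x, server_lr, clients)
    server_lr -= meta_lr * hg  # hypergradient descent on the server learning rate
    losses.append(loss)
```

For this quadratic toy problem the averaged delta equals rho * (x - c_bar), with rho = 1 - 0.9**5 and c_bar the mean client target. The hypergradient is therefore -rho * (1 - server_lr * rho) * ||x - c_bar||^2, which is negative whenever `server_lr` is below 1/rho (about 2.44). Hypergradient descent thus raises `server_lr` above its initial value of 1.0 while the loss estimate decreases.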