
Learning from straggler clients in federated learning

Andrew Hard, Antonious M. Girgis, Ehsan Amid, Sean Augenstein, Lara McConnaughey, Rajiv Mathews, Rohan Anil

Abstract

How well do existing federated learning algorithms learn from client devices that return model updates with a significant time delay? Is it even possible to learn effectively from clients that report back minutes, hours, or days after being scheduled? We answer these questions by developing Monte Carlo simulations of client latency that are guided by real-world applications. We study synchronous optimization algorithms like FedAvg and FedAdam as well as the asynchronous FedBuff algorithm, and observe that all these existing approaches struggle to learn from severely delayed clients. To improve upon this situation, we experiment with modifications, including distillation regularization and exponential moving averages of model weights. Finally, we introduce two new algorithms, FARe-DUST and FeAST-on-MSG, based on distillation and averaging, respectively. Experiments with the EMNIST, CIFAR-100, and StackOverflow benchmark federated learning tasks demonstrate that our new algorithms outperform existing ones in terms of accuracy for straggler clients, while also providing better trade-offs between training time and total accuracy.
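The abstract mentions Monte Carlo simulations of client latency, and Figure 1 names two variants: a per-example model and a per-domain per-example model. This section does not give the exact distributions or parameters, so the Python sketch below is only an illustration under stated assumptions. It draws per-example processing times i.i.d. from an exponential distribution and assigns each client to one hypothetical domain (`Domain`, with made-up weights and scales). A client's total latency is the sum of its per-example times. The helper names, numeric values, and the straggler cutoff are all hypothetical.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Domain:
    """Hypothetical client group sharing one latency scale (illustrative only)."""
    name: str
    weight: float                # probability that a client belongs to this domain
    mean_sec_per_example: float  # assumed mean processing time per example


def per_example_latency(num_examples: int, mean_sec: float, rng: random.Random) -> float:
    """Per-example model: total latency is a sum of i.i.d. exponential draws (assumed)."""
    return sum(rng.expovariate(1.0 / mean_sec) for _ in range(num_examples))


def per_domain_per_example_latency(
    num_examples: int, domains: List[Domain], rng: random.Random
) -> Tuple[str, float]:
    """Per-domain per-example model: sample a domain, then use its per-example scale."""
    domain = rng.choices(domains, weights=[d.weight for d in domains], k=1)[0]
    return domain.name, per_example_latency(num_examples, domain.mean_sec_per_example, rng)


def simulate(
    client_sizes: List[int], domains: List[Domain], seed: int = 0
) -> List[Tuple[str, float]]:
    """One Monte Carlo draw of (domain name, total latency in seconds) per client."""
    rng = random.Random(seed)
    return [per_domain_per_example_latency(n, domains, rng) for n in client_sizes]


if __name__ == "__main__":
    domains = [Domain("fast", 0.8, 0.05), Domain("slow", 0.2, 5.0)]  # made-up values
    latencies = simulate([10, 50, 200, 1000], domains, seed=42)
    cutoff_sec = 60.0  # hypothetical cutoff for labeling a client a straggler
    stragglers = [i for i, (_, t) in enumerate(latencies) if t > cutoff_sec]
    print(latencies, stragglers)
```

Changing the domain weights or per-example scales reshapes the heavy tail of the simulated latency distribution, which is the region that produces straggler clients.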

Paper Structure

This paper contains 54 sections, 3 equations, 23 figures, 25 tables, and 3 algorithms.

Figures (23)

  • Figure 1: The distribution of total client latencies, in seconds, sampled for the StackOverflow dataset with the per-example latency model (a) and the per-domain per-example latency model (b). PDFs (top) are stacked, while CDFs (bottom) are overlaid.
  • Figure 2: Straggler accuracy (left) and total accuracy (right) as a function of wall clock training time for EMNIST with the per-domain per-example latency model. Solid lines and bands represent the median and 90% confidence intervals from 10 trials.
  • Figure 3: Straggler accuracy as a function of total accuracy (left) and wall clock training time (right) for CIFAR-100 with the per-domain per-example latency model. Plots show the median and 90% confidence interval over 10 trials, using the best straggler accuracy value in each trial.
  • Figure 4: A diagram of the three main components of FARe-DUST: (a) stale gradient accumulation, (b) teacher network creation, and (c) stale teacher distillation.
  • Figure 5: A visual diagram of the FeAST-on-MSG algorithm. The primary weight branch ($w_{t}, w_{t+1}, w_{t+2}, \dots$) is advanced as in standard FedAvg. In addition, straggler updates are allowed to accumulate for each checkpoint as the server receives them ($\Delta_{t}^{+}, \Delta_{t+1}^{+}, \Delta_{t+2}^{+}, \dots$), up to a maximum wait time $\tau_{\textrm{max}}$. Once the maximum wait time has elapsed, a historical checkpoint is formed using the update $w_{t+1}^{+} = w_t - \frac{\eta_{g}}{B_{t}^{+}} \Delta_{t}^{+}$. Historical checkpoints $w_{t+1}^{+}, w_{t+2}^{+}, w_{t+3}^{+}, \dots$ are formed as the maximum wait time for each round elapses. An auxiliary branch ($a_{t+1}, a_{t+2}, a_{t+3}, \dots$) is advanced based on the historical checkpoints and stale gradients: $a_{t+1} = \beta \left(a_t - \frac{\eta_a}{B^{+}_t} \Delta^{+}_t\right) + (1 - \beta) w^{+}_{t+1}$. After a maximum of $T$ training rounds, $a_{T}$ is returned as the final model.
  • ...and 18 more figures
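The two update rules in the Figure 5 caption can be written out directly. The Python sketch below replays the FeAST-on-MSG auxiliary branch, given the primary-branch checkpoints $w_t$ and, for each one, the accumulated update $\Delta_t^{+}$ and the count $B_t^{+}$. It makes three assumptions that the caption does not state: $a_0 = w_0$; $\Delta_t^{+}$ sums every update computed from $w_t$ that arrives before $\tau_{\textrm{max}}$ elapses; and a round with no received updates leaves $a$ unchanged. Models are plain Python lists to keep the example self-contained. The function names are hypothetical, and the sketch omits the client training and FedAvg step that produce the checkpoints.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def axpy(alpha: float, x: Sequence[float], y: Sequence[float]) -> Vector:
    """Elementwise alpha * x + y."""
    return [alpha * xi + yi for xi, yi in zip(x, y)]


def historical_checkpoint(w_t, delta_plus, count_plus, eta_g) -> Vector:
    """w_{t+1}^+ = w_t - (eta_g / B_t^+) * Delta_t^+."""
    return axpy(-eta_g / count_plus, delta_plus, w_t)


def auxiliary_step(a_t, delta_plus, count_plus, w_hist_next, eta_a, beta) -> Vector:
    """a_{t+1} = beta * (a_t - (eta_a / B_t^+) * Delta_t^+) + (1 - beta) * w_{t+1}^+."""
    inner = axpy(-eta_a / count_plus, delta_plus, a_t)
    return [beta * x + (1.0 - beta) * h for x, h in zip(inner, w_hist_next)]


def feast_on_msg_auxiliary(
    checkpoints: List[Vector],
    accumulated: List[Tuple[Vector, int]],
    eta_g: float,
    eta_a: float,
    beta: float,
) -> Vector:
    """Replay the auxiliary branch and return a_T.

    checkpoints[t] is the primary-branch model w_t; accumulated[t] is
    (Delta_t^+, B_t^+): the summed updates computed from w_t that arrived
    before tau_max elapsed, and how many updates were summed.
    """
    a = list(checkpoints[0])  # assumption: a_0 = w_0
    for w_t, (delta_plus, count_plus) in zip(checkpoints, accumulated):
        if count_plus == 0:
            continue  # assumption: a round with no received updates leaves a unchanged
        w_hist = historical_checkpoint(w_t, delta_plus, count_plus, eta_g)
        a = auxiliary_step(a, delta_plus, count_plus, w_hist, eta_a, beta)
    return a
```

Setting `beta=0.0` makes $a_{t+1}$ equal the historical checkpoint $w_{t+1}^{+}$, while `beta=1.0` applies only the stale update to $a_t$. Intermediate values interpolate between the two, similar to an exponential moving average.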