Asynchronous Heavy-Tailed Optimization

Junfei Sun, Dixi Yao, Xuchen Gong, Tahseen Rabbani, Manzil Zaheer, Tian Li

TL;DR

This work investigates two communication schemes that handle stragglers via asynchronous updates in the presence of heavy-tailed gradient noise, and it proposes and theoretically analyzes algorithmic modifications, based on delay-aware learning rate scheduling and delay compensation, that improve the performance of asynchronous algorithms.

Abstract

Heavy-tailed stochastic gradient noise, commonly observed in transformer models, can destabilize the optimization process. Recent works mainly focus on developing and understanding approaches to address heavy-tailed noise in the centralized or distributed, synchronous setting, leaving the interactions between such noise and asynchronous optimization underexplored. In this work, we investigate two communication schemes that handle stragglers with asynchronous updates in the presence of heavy-tailed gradient noise. We propose and theoretically analyze algorithmic modifications based on delay-aware learning rate scheduling and delay compensation to enhance the performance of asynchronous algorithms. Our convergence guarantees under heavy-tailed noise match the rates of their synchronous counterparts and improve delay tolerance compared with existing asynchronous approaches. Empirically, our approaches outperform prior synchronous and asynchronous methods in terms of accuracy/runtime trade-offs and are more robust to hyperparameters in both image and language tasks.
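The abstract names two server-side modifications, delay-aware learning rate scheduling and delay compensation. The sketch below illustrates generic versions of both for a server applying a single stale client gradient; it is not the paper's algorithm. Assumptions: the delay-aware step size is taken as $\eta/(1+p)$ for delay $p$, a common staleness-aware heuristic, and the compensation is the first-order correction used by DC-ASGD, $g + \lambda\, g \odot g \odot (w_t - w_{t-p})$. The helper names `delay_aware_lr`, `compensate`, and `server_apply` are hypothetical.

```python
from typing import List

Vector = List[float]


def delay_aware_lr(base_lr: float, delay: int) -> float:
    """Assumed schedule: scale the step size by 1 / (1 + delay)."""
    if delay < 0:
        raise ValueError("delay must be non-negative")
    return base_lr / (1.0 + delay)


def compensate(grad: Vector, w_now: Vector, w_stale: Vector, lam: float) -> Vector:
    """DC-ASGD-style correction of a gradient computed at w_stale.

    Approximates the gradient at w_now as g + lam * g * g * (w_now - w_stale),
    using the elementwise square g * g as a cheap diagonal Hessian surrogate.
    lam = 0 returns the stale gradient unchanged.
    """
    return [g + lam * g * g * (a - b) for g, a, b in zip(grad, w_now, w_stale)]


def server_apply(w_now: Vector, grad: Vector, w_stale: Vector, delay: int,
                 base_lr: float = 0.1, lam: float = 0.0) -> Vector:
    """Apply one stale client gradient with a delay-aware step and optional compensation."""
    lr = delay_aware_lr(base_lr, delay)
    g = compensate(grad, w_now, w_stale, lam)
    return [w - lr * gi for w, gi in zip(w_now, g)]


if __name__ == "__main__":
    w_stale = [1.0, -2.0]  # model the client started from
    w_now = [0.5, -1.0]    # server model after `delay` other updates were applied
    grad = [2.0, -4.0]     # client gradient, computed at w_stale
    print(server_apply(w_now, grad, w_stale, delay=3, base_lr=0.4))           # no compensation
    print(server_apply(w_now, grad, w_stale, delay=3, base_lr=0.4, lam=0.1))  # with compensation
```

Setting `lam=0` disables compensation, so the effect of each modification can be inspected separately; the $1/(1+p)$ decay is only one of many possible delay-aware schedules.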

Paper Structure

This paper contains 39 sections, 16 theorems, 115 equations, 5 figures, 17 tables, and 2 algorithms.

Key Result

Theorem 1

Assume $F(\cdot)$ is $L$-smooth and $G$-Lipschitz. Let $u$ denote the client-side gradient clipping threshold, set to $u=\Theta(T^{\zeta})$, and let $p_{t,j}$ be the delay of the update from client $j$ received at global round $t$. If the stochastic gradient noise is heavy-tailed, i.e., it … In particular, if all $p_{t,j}$'s take the same value $p$, and $p=\Theta(T^b)$ with $b\leq \frac{1}{\cdots}$ …
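Theorem 1 ties the client-side clipping threshold to the horizon, $u=\Theta(T^{\zeta})$. The sketch below shows one concrete instantiation, $u = c\,T^{\zeta}$, together with standard norm clipping $\mathrm{clip}_u(g) = \min(1, u/\lVert g\rVert_2)\,g$, applied to toy heavy-tailed gradients. The constant `c`, the example exponent, the Pareto-tailed noise model, and the helper names are all hypothetical; the theorem only fixes the growth rate of $u$.

```python
import math
import random
from typing import List


def clip_threshold(T: int, zeta: float, c: float = 1.0) -> float:
    """Hypothetical instantiation u = c * T**zeta of the rate u = Theta(T^zeta)."""
    return c * T ** zeta


def clip_to_norm(g: List[float], u: float) -> List[float]:
    """Norm clipping: return min(1, u / ||g||_2) * g."""
    norm = math.sqrt(sum(x * x for x in g))
    if norm <= u:
        return list(g)
    return [x * (u / norm) for x in g]


def heavy_tailed_grad(dim: int, rng: random.Random) -> List[float]:
    """Toy gradient with symmetric Pareto tails (index 1.5): finite mean, infinite variance."""
    return [rng.choice((-1.0, 1.0)) * rng.paretovariate(1.5) for _ in range(dim)]


if __name__ == "__main__":
    rng = random.Random(0)
    u = clip_threshold(T=10_000, zeta=0.25)  # c = 1 and zeta = 0.25 are arbitrary example values
    norms = [math.sqrt(sum(x * x for x in clip_to_norm(heavy_tailed_grad(10, rng), u)))
             for _ in range(1_000)]
    print(f"u = {u:.2f}; largest clipped norm over 1000 draws = {max(norms):.2f}")
```

However extreme an individual draw is, every clipped gradient has norm at most $u$, which is what lets the analysis trade the bias from clipping against the growth of $u$ with $T$.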

Figures (5)

  • Figure 1: Illustration of the server-centric and client-centric asynchronous models with $N=3$ workers in total, $K=3$ local steps, and an asynchronous buffer of size $M=2$.
  • Figure 2: Test loss versus number of epochs (left) and runtime (right) of synchronous and asynchronous SGDClip/Clip$^2$ with large stragglers. The asynchronous methods reach test loss similar to their synchronous versions in significantly less runtime, across optimizers including Clip$^2$, which is designed to handle heavy-tailed noise. Moreover, comparing 'CC Async SGDClip' with 'CC Async Clip$^2$' (or 'SC Async SGDClip' with 'SC Async Clip$^2$'), where CC and SC denote the client-centric and server-centric models, we see that Clip$^2$ inherently helps control the bias caused by asynchrony. Similar trends hold in the mild straggler setting (Figure 5 in the appendix).
  • Figure 3: Test loss versus number of epochs and runtime of Sync SGDClip, Async SGDClip, and Clip$^2$ under one specific hyperparameter choice on CIFAR-10. 'cr', 'sr', 'cu', and 'su' denote the client learning rate, server learning rate, client upper clipping bound, and server upper clipping bound, respectively. The vanilla asynchronous method can diverge, whereas Clip$^2$, originally designed to handle heavy-tailed noise, can help it converge.
  • Figure 4: Test loss versus number of epochs of the vanilla asynchronous method and its variant with staleness-aware downplaying (SD), both using the SGDClip optimizer, for three hyperparameter choices. SD substantially improves robustness to the choice of hyperparameters.
  • Figure 5: Best test loss versus epochs and runtime of Sync SGDClip/Clip$^2$ and Async SGDClip/Clip$^2$ under the mild straggler setting on CIFAR-10.
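Figures 1 and 3 refer to $N$ workers, $K$ local steps, a buffer of size $M$, and the four knobs cr, sr, cu, su. The following single-process simulation is a hedged sketch of how these might interact in a server-centric buffered scheme: each finishing client runs $K$ clipped local SGD steps (client learning rate `cr`, client bound `cu`) from the model version it last pulled, and the server waits for $M$ client deltas, averages them, clips the average to `su`, and applies it with server learning rate `sr`. This two-level clipping is a simplified reading of the captions, not the paper's exact Clip$^2$ algorithm; the quadratic objective, the Pareto-tailed noise, the random finishing order standing in for stragglers, and all function names are assumptions.

```python
import math
import random
from typing import List

Vec = List[float]


def clip(v: Vec, bound: float) -> Vec:
    """Rescale v so that its L2 norm is at most `bound`."""
    n = math.sqrt(sum(x * x for x in v))
    return list(v) if n <= bound else [x * bound / n for x in v]


def noisy_grad(w: Vec, rng: random.Random) -> Vec:
    """Gradient of the toy objective 0.5 * ||w||^2 plus zero-mean, Pareto-tailed noise."""
    return [x + rng.choice((-1.0, 1.0)) * (rng.paretovariate(1.5) - 1.0) for x in w]


def local_update(w_start: Vec, K: int, cr: float, cu: float, rng: random.Random) -> Vec:
    """K steps of clipped SGD on one client; returns the delta w_start - w_end."""
    w = list(w_start)
    for _ in range(K):
        g = clip(noisy_grad(w, rng), cu)
        w = [x - cr * gi for x, gi in zip(w, g)]
    return [a - b for a, b in zip(w_start, w)]


def buffered_async(rounds: int, N: int = 3, K: int = 3, M: int = 2,
                   cr: float = 0.1, sr: float = 1.0, cu: float = 5.0, su: float = 1.0,
                   seed: int = 0) -> Vec:
    """Server-centric buffered asynchrony with client-side and server-side clipping."""
    rng = random.Random(seed)
    w = [5.0, -5.0]
    pulled = [list(w) for _ in range(N)]  # model version each client is working from
    for _ in range(rounds):
        buffer: List[Vec] = []
        while len(buffer) < M:
            j = rng.randrange(N)           # random finishing order stands in for stragglers
            buffer.append(local_update(pulled[j], K, cr, cu, rng))
            pulled[j] = list(w)            # client j restarts from the current server model
        avg = [sum(d[i] for d in buffer) / M for i in range(len(w))]
        w = [x - sr * di for x, di in zip(w, clip(avg, su))]
    return w


if __name__ == "__main__":
    w_final = buffered_async(rounds=200)
    print("final model:", [round(x, 3) for x in w_final])
```

Clients that are not picked keep working from older server models, so their deltas arrive stale; the server-side bound `su` caps how far any single flush of the buffer, stale or heavy-tailed, can move the model.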

Theorems & Definitions (27)

  • Theorem 1
  • Remark 1
  • Proposition 1
  • Theorem 2
  • Proposition 2
  • Theorem 3
  • Remark 2
  • Theorem 4
  • Proof
  • Theorem 5
  • ...and 17 more