
Distributed Sign Momentum with Local Steps for Training Transformers

Shuhua Yu, Ding Zhou, Cong Xie, An Xu, Zhi Zhang, Xin Liu, Soummya Kar

TL;DR

This paper introduces a flexible framework for distributed sign momentum with local steps, aimed at accelerating Transformer pretraining under tight communication budgets. Workers perform $\tau$ local updates with any base optimizer. A global sign momentum step then aggregates the local differences $\boldsymbol{x}_{t,0}-\boldsymbol{x}_{t,\tau}$ and updates the model with a Lion-like momentum term. The authors provide two convergence analyses. The first is a general result for randomized sign operators that specializes to an $O(1/\sqrt{T})$ rate for a particular instance. The second is an optimal $O(1/T^{1/4})$ rate in the $\ell_1$ gradient norm for the exact sign operator when the local optimizer is SGD. Empirically, the method significantly outperforms multi-local-step baselines such as SlowMo on GPT-2 pretraining across multiple model scales, while communicating only once every $\tau$ steps. The work offers a principled, scalable approach to distributed Transformer training and motivates further exploration in other domains such as Vision Transformers.
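The round structure described above can be sketched in a few lines of NumPy. This is a hypothetical illustration, not the authors' implementation. The function names, the toy quadratic objective, every hyperparameter value, and the exact Lion-style form of the global step are our assumptions. That form takes the sign of $\beta_1 m + (1-\beta_1)\Delta$ and then sets $m \leftarrow \beta_2 m + (1-\beta_2)\Delta$, where $\Delta$ is the worker-averaged $\boldsymbol{x}_{t,0}-\boldsymbol{x}_{t,\tau}$. Workers are simulated sequentially, and plain SGD stands in for the base optimizer.

```python
import numpy as np


def local_sgd(x, grad_fn, lr, tau, rng):
    """Run tau plain SGD steps starting from x and return the final iterate."""
    x = x.copy()
    for _ in range(tau):
        x -= lr * grad_fn(x, rng)
    return x


def dist_sign_momentum(x0, grad_fn, n_workers=4, tau=12, rounds=300,
                       local_lr=0.05, global_lr=0.01,
                       beta1=0.9, beta2=0.9, seed=0):
    """Hypothetical sketch: tau local SGD steps per worker, then one Lion-style
    global sign-momentum step on the averaged difference x_{t,0} - x_{t,tau}."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    m = np.zeros_like(x)  # global momentum buffer
    comm_rounds = 0
    for _ in range(rounds):
        # Every worker starts from the shared x_{t,0} and runs tau local steps.
        deltas = [x - local_sgd(x, grad_fn, local_lr, tau, rng)
                  for _ in range(n_workers)]
        delta = np.mean(deltas, axis=0)  # the only communication in a round
        comm_rounds += 1
        # Assumed Lion-style step: sign of an interpolation of momentum and delta.
        x = x - global_lr * np.sign(beta1 * m + (1 - beta1) * delta)
        m = beta2 * m + (1 - beta2) * delta
    return x, comm_rounds


target = np.ones(10)


def noisy_grad(x, rng):
    # Gradient of the toy objective 0.5 * ||x - target||^2 plus Gaussian noise.
    return x - target + 0.1 * rng.standard_normal(x.shape)


x_final, comms = dist_sign_momentum(np.zeros(10), noisy_grad)
print(comms)  # 300 averaging rounds for 300 * 12 = 3600 local steps per worker
```

Each round costs exactly one averaging step, however large $\tau$ is. Doubling $\tau$ from 12 to 24, the two intervals compared in the figures, therefore halves the number of communication rounds for the same number of local steps.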

Abstract

Pre-training Transformer models is resource-intensive, and recent studies have shown that sign momentum is an efficient technique for training large-scale deep learning models, particularly Transformers. However, its application in distributed training remains underexplored. This paper investigates a novel communication-efficient distributed sign momentum method with multiple local steps, designed for scenarios where communicating at every step is prohibitive. Our proposed method allows a broad class of base optimizers for the local steps and uses sign momentum in the global step, where the momentum is generated from differences accumulated during the local steps. For generic base optimizers, we approximate the sign operator with a randomized version that acts as a continuous analog in expectation and present a general convergence analysis, which specializes to an $O(1/\sqrt{T})$ rate for a particular instance. When the local optimizer is stochastic gradient descent (SGD), we show an optimal $O(1/T^{1/4})$ rate in terms of the $\ell_1$ gradient norm for nonconvex smooth cost functions. We extensively evaluate our method on pre-training GPT-2 models of various sizes from scratch, and the empirical results show significant improvement over other distributed methods with multiple local steps.
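As a reading aid, the rate can be converted into an iteration count. This is a hedged sketch that assumes the SGD result has the form $\min_{t<T}\mathbb{E}\|\nabla f(\boldsymbol{x}_t)\|_1 \le C/T^{1/4}$, where the constant $C$ is our notation, not the paper's. The standard norm inequalities below relate the $\ell_1$ guarantee to the more common $\ell_2$ one.

```latex
\frac{C}{T^{1/4}} \le \epsilon
\;\Longleftrightarrow\;
T \ge \left(\frac{C}{\epsilon}\right)^{4} = O(\epsilon^{-4}),
\qquad
\|\boldsymbol{g}\|_2 \le \|\boldsymbol{g}\|_1 \le \sqrt{d}\,\|\boldsymbol{g}\|_2 .
```

Under this form, halving the target $\epsilon$ multiplies the required $T$ by 16. Because $\|\boldsymbol{g}\|_1 \le \epsilon$ implies $\|\boldsymbol{g}\|_2 \le \epsilon$, an $\ell_1$ guarantee is at least as strong as the same guarantee stated in $\ell_2$.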

Paper Structure

This paper contains 28 sections, 5 theorems, 70 equations, 7 figures, 6 tables, 5 algorithms.

Key Result

Lemma 1

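For a random vector $\boldsymbol{v} \in \mathbb{R}^d$ that satisfies $\|\boldsymbol{v}\| \le B$ almost surely, the randomized sign operator $\mathcal{S}_r$ satisfies $\mathbb{E}_{\mathcal{S}}[\mathcal{S}_r(\boldsymbol{v})] = \boldsymbol{v}/B$ and $\mathbb{E}_{\mathcal{S}}[\| \mathcal{S}_r(\boldsymbol{v}) - \boldsymbol{v}/B \|^2] \le d$, where $\mathbb{E}_{\mathcal{S}}$ denotes expectation over the randomness of $\mathcal{S}_r$.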
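Lemma 1 does not spell out $\mathcal{S}_r$ here. The sketch below therefore assumes one standard construction consistent with both claims: coordinate $i$ of $\mathcal{S}_r(\boldsymbol{v})$ equals $+1$ with probability $(1+v_i/B)/2$ and $-1$ otherwise. Each coordinate then has mean $v_i/B$ and variance $1-(v_i/B)^2$, so $\mathbb{E}\|\mathcal{S}_r(\boldsymbol{v})-\boldsymbol{v}/B\|^2 = d - \|\boldsymbol{v}\|_2^2/B^2 \le d$. The names `randomized_sign` and `check_lemma1` are hypothetical. The code only checks these two identities by Monte Carlo and is not taken from the paper.

```python
import numpy as np


def randomized_sign(v, B, rng):
    """Hypothetical S_r: each entry is +1 w.p. (1 + v_i/B)/2 and -1 otherwise.

    Needs |v_i| <= B for every entry, which ||v|| <= B guarantees.
    """
    v = np.asarray(v, dtype=float)
    if np.any(np.abs(v) > B):
        raise ValueError("randomized_sign requires |v_i| <= B")
    p_plus = 0.5 * (1.0 + v / B)
    return np.where(rng.random(v.shape) < p_plus, 1.0, -1.0)


def check_lemma1(v, B, n_samples=100_000, seed=0):
    """Monte Carlo estimates of E[S_r(v)] and E||S_r(v) - v/B||^2 for a fixed v."""
    rng = np.random.default_rng(seed)
    v = np.asarray(v, dtype=float)
    # Each row of `batch` is a copy of v, so each row gets an independent draw.
    batch = np.broadcast_to(v, (n_samples, v.size))
    samples = randomized_sign(batch, B, rng)
    mean = samples.mean(axis=0)
    sq_err = np.sum((samples - v / B) ** 2, axis=1).mean()
    return mean, sq_err


v = np.array([0.3, -0.5, 0.0, 0.8])  # ||v||_2 is about 0.99 <= B = 1
mean, sq_err = check_lemma1(v, B=1.0)
print(mean)    # close to v / B
print(sq_err)  # close to d - ||v||^2 / B^2 = 4 - 0.98 = 3.02, below d = 4
```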

Figures (7)

  • Figure 1: Validation loss curves versus communication rounds for GPT-2 small, medium, and large. The communication interval is set to $\tau = 12$ for both our distributed sign momentum algorithm and SlowMo.
  • Figure 2: Validation loss curves versus computation rounds for GPT-2 small, medium, and large. The communication interval is set to $\tau = 12$ for both our distributed sign momentum algorithm and SlowMo.
  • Figure 3: Validation loss curves for communication intervals $\tau = 12$ and $\tau = 24$.
  • Figure 4: Training loss curves for communication interval $\tau = 12$.
  • Figure 5: Validation loss curves for communication interval $\tau = 24$.

Theorems & Definitions (11)

  • Lemma 1
  • Theorem 1
  • Theorem 2
  • Remark 1
  • Theorem 3
  • Remark 2
  • proof
  • proof
  • Lemma 2
  • proof