Distributed Sign Momentum with Local Steps for Training Transformers
Shuhua Yu, Ding Zhou, Cong Xie, An Xu, Zhi Zhang, Xin Liu, Soummya Kar
TL;DR
This paper introduces a flexible framework for distributed sign momentum with local steps, aimed at accelerating Transformer pretraining under tight communication budgets. Each worker performs $\tau$ local updates with any base optimizer. A global sign momentum step then aggregates the local differences $\boldsymbol{x}_{t,0}-\boldsymbol{x}_{t,\tau}$ and updates the model with a Lion-like momentum term. The authors provide two convergence analyses. The first is a general result for a randomized sign operator, which specializes to an $O(1/\sqrt{T})$ rate for a particular instance. The second is an optimal $O(1/T^{1/4})$ rate in terms of the $\ell_1$ gradient norm for nonconvex smooth objectives, using the exact sign operator with SGD as the local optimizer. Empirically, the method significantly outperforms multi-local-step baselines such as SlowMo on GPT-2 pretraining at multiple model scales, while communicating only once every $\tau$ steps. The work offers a principled, scalable approach to distributed Transformer training and motivates further exploration in other domains, such as Vision Transformers.
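The round structure described above can be illustrated with a small pure-Python sketch. It assumes plain SGD as the base optimizer and a global step that follows the Lion update rule, with the worker-averaged difference $\boldsymbol{x}_{t,0}-\boldsymbol{x}_{t,\tau}$ playing the role of the gradient. The names `sign_momentum_round`, `local_sgd`, `lr_local`, `lr_global`, `beta1` and `beta2` are hypothetical, and Lion's weight decay is omitted. This is an illustration under those assumptions, not the authors' implementation.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
GradFn = Callable[[Vector], Vector]  # hypothetical: returns a (stochastic) gradient at x


def sign(v: Vector) -> Vector:
    """Elementwise sign, with sign(0) = 0."""
    return [1.0 if x > 0 else (-1.0 if x < 0 else 0.0) for x in v]


def local_sgd(x0: Vector, grad_fn: GradFn, tau: int, lr_local: float) -> Vector:
    """Run tau plain SGD steps from x0 and return x_{t,tau}.

    SGD is an assumption here; the framework allows other base optimizers.
    """
    x = list(x0)
    for _ in range(tau):
        g = grad_fn(x)
        x = [xi - lr_local * gi for xi, gi in zip(x, g)]
    return x


def sign_momentum_round(
    x: Vector,
    m: Vector,
    worker_grads: Sequence[GradFn],
    tau: int,
    lr_local: float,
    lr_global: float,
    beta1: float = 0.9,
    beta2: float = 0.99,
) -> Tuple[Vector, Vector]:
    """One communication round: tau local steps per worker, then a Lion-style global sign step."""
    # Every worker starts from the shared global model x_{t,0}.
    finals = [local_sgd(x, g, tau, lr_local) for g in worker_grads]
    n = len(finals)
    # Differences x_{t,0} - x_{t,tau}, averaged across workers (one all-reduce per round).
    delta = [sum(x[i] - xf[i] for xf in finals) / n for i in range(len(x))]
    # Lion-style step: take the sign of an interpolation of momentum and the new difference.
    c = [beta1 * mi + (1.0 - beta1) * di for mi, di in zip(m, delta)]
    x_new = [xi - lr_global * si for xi, si in zip(x, sign(c))]
    # The momentum buffer tracks the averaged differences with coefficient beta2.
    m_new = [beta2 * mi + (1.0 - beta2) * di for mi, di in zip(m, delta)]
    return x_new, m_new


if __name__ == "__main__":
    # Toy setup: two workers with quadratic losses 0.5 * ||x - a_k||^2.
    def make_grad(a: Vector) -> GradFn:
        return lambda x: [xi - ai for xi, ai in zip(x, a)]

    workers = [make_grad([1.0, 1.0]), make_grad([-1.0, -1.0])]
    x, m = [5.0, -3.0], [0.0, 0.0]
    for _ in range(3):
        x, m = sign_momentum_round(x, m, workers, tau=5, lr_local=0.1, lr_global=0.05)
        print(x, m)
```

Because the global step applies a sign, every coordinate moves by exactly `lr_global` (or not at all) per round, no matter how large the accumulated local difference is. This is the property that makes the step behave like Lion rather than like SlowMo's momentum update.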
Abstract
Pre-training Transformer models is resource-intensive, and recent studies have shown that sign momentum is an efficient technique for training large-scale deep learning models, particularly Transformers. However, its application in distributed training remains underexplored. This paper investigates a novel communication-efficient distributed sign momentum method with multiple local steps, designed for scenarios where communicating at every step is prohibitive. Our proposed method allows a broad class of base optimizers for the local steps and uses sign momentum in the global step, where the momentum is generated from the differences accumulated during the local steps. For generic base optimizers, we approximate the sign operator with a randomized version that acts as a continuous analog in expectation, and we present a general convergence analysis that specializes to an $O(1/\sqrt{T})$ rate for a particular instance. When the local steps use stochastic gradient descent (SGD), we show an optimal $O(1/T^{1/4})$ rate in terms of the $\ell_1$ gradient norm for nonconvex smooth cost functions. We extensively evaluate our method on pre-training GPT-2 models of various sizes from scratch, and the empirical results show significant improvements over other distributed methods with multiple local steps.
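To make "a continuous analog in expectation" concrete, the sketch below uses one common construction of a randomized sign operator. It is an assumption and not necessarily the exact operator analyzed in the paper. For a bound $B \ge \max_i |v_i|$, each coordinate is set to $+1$ with probability $(1 + v_i/B)/2$ and to $-1$ otherwise, so $\mathbb{E}[S(v)_i] = v_i/B$. The name `randomized_sign` and the parameter `bound` are hypothetical.

```python
import random
from typing import List


def randomized_sign(v: List[float], bound: float, rng: random.Random) -> List[float]:
    """Return a random vector in {-1, +1}^d whose expectation is v / bound.

    Assumes every |v_i| <= bound. Coordinates at +bound or -bound are mapped
    deterministically to +1 or -1, matching the exact sign there.
    """
    if bound <= 0:
        raise ValueError("bound must be positive")
    out = []
    for x in v:
        if abs(x) > bound:
            raise ValueError("every |v_i| must be at most bound")
        p_plus = 0.5 * (1.0 + x / bound)  # probability of emitting +1
        out.append(1.0 if rng.random() < p_plus else -1.0)
    return out


if __name__ == "__main__":
    # Monte Carlo check: the empirical mean should approach v / bound.
    rng = random.Random(0)
    v, bound, n = [0.5, -0.2, 0.0], 1.0, 20000
    sums = [0.0] * len(v)
    for _ in range(n):
        sums = [s + z for s, z in zip(sums, randomized_sign(v, bound, rng))]
    print([s / n for s in sums])
```

Unlike the exact sign, this operator is unbiased for the scaled vector $v/B$. That property is what allows a smooth-analysis-style argument for generic base optimizers.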
