EDiT: A Local-SGD-Based Efficient Distributed Training Method for Large Language Models

Jialiang Cheng, Ning Gao, Yun Yue, Zhiling Ye, Jiadi Jiang, Jian Sha

TL;DR

EDiT addresses critical bottlenecks in distributed LLM pretraining by fusing model sharding with a Local SGD framework on a 2D device mesh, enabling layer-wise, overlapped synchronization that reduces communication and memory overhead. It introduces a pseudo gradient penalty to stabilize training in diverse data regimes and an asynchronous variant, A-EDiT, to handle heterogeneity in large clusters. Empirical results across Llama scales (350M–7B) and multiple datasets show improved convergence, generalization, and throughput, with robust performance under stragglers and elastic resource changes. Theoretical analysis provides a convergence guarantee of $O\left(\dfrac{\log(T)}{\sqrt{T}}\right)$ under standard assumptions, supporting the practicality of scalable, asynchronous distributed training for LLMs.
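A 2D device mesh of this kind can be sketched as two families of rank groups. The sketch assumes a row-major $M \times N$ layout with the following axis assignment:

- Each row is a model sharding group: one model replica whose parameters are sharded across that row's devices.
- Each column is a model sync group: the devices that hold the same shard in different replicas.

The function name `build_mesh_groups` and this axis assignment are illustrative assumptions, not taken from the EDiT code.

```python
from typing import List, Tuple


def build_mesh_groups(num_rows: int, num_cols: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_groups, sync_groups) for a row-major num_rows x num_cols mesh.

    Hypothetical layout: rows shard one replica, columns sync matching shards.
    """
    # Rank at mesh position (r, c) under row-major numbering.
    ranks = [[r * num_cols + c for c in range(num_cols)] for r in range(num_rows)]
    # Devices that together hold one full (sharded) model replica.
    shard_groups = [row[:] for row in ranks]
    # Devices that own the same shard in different replicas; Local SGD averages within these.
    sync_groups = [[ranks[r][c] for r in range(num_rows)] for c in range(num_cols)]
    return shard_groups, sync_groups


if __name__ == "__main__":
    shard, sync = build_mesh_groups(2, 2)
    print("shard groups:", shard)  # shard groups: [[0, 1], [2, 3]]
    print("sync groups:", sync)    # sync groups: [[0, 2], [1, 3]]
```

In a PyTorch job, each rank list could be passed to `torch.distributed.new_group` to create the matching communicator. Under this layout, synchronizing within a column means each device exchanges only its own shard rather than the full model.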

Abstract

Distributed training methods are crucial for large language models (LLMs). However, existing distributed training methods often suffer from communication bottlenecks, stragglers, and limited elasticity, particularly in heterogeneous or large-scale environments. Local SGD methods have been proposed to address these issues, but their effectiveness remains limited to small-scale training due to additional memory overhead and insufficient attention to efficiency and stability. To tackle these issues, we propose EDiT, an innovative Efficient Distributed Training method that combines a tailored Local SGD approach with model sharding techniques to enhance large-scale training efficiency. EDiT performs layer-wise parameter synchronization during the forward pass, which reduces communication and memory overhead and allows communication to overlap with computation. In addition, EDiT employs a pseudo gradient penalty strategy to suppress loss spikes, which ensures training stability and improves performance. We also introduce A-EDiT, a fully asynchronous variant of EDiT that accommodates heterogeneous clusters. Building on EDiT/A-EDiT, we conduct a series of experiments to validate large-scale asynchronous training for LLMs, accompanied by comprehensive analyses. Experimental results demonstrate the superior performance of EDiT/A-EDiT, establishing them as robust solutions for distributed LLM training in diverse computational ecosystems. The code is available in the Atorch codebase: https://github.com/intelligent-machine-learning/atorch/tree/main/atorch/local_sgd.
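The synchronization step described above can be illustrated with a minimal, framework-free sketch.

- **Parameters:** treated as one flat list. EDiT synchronizes layer by layer, so the same logic would run once per layer.
- **Pseudo gradient:** each worker's pseudo gradient is the shared starting parameters minus its locally updated parameters.
- **Penalty rules:** three simplified rules are applied: median-ratio outlier filtering, inverse-norm weighting, and norm clipping.

These rules, their thresholds, the plain-SGD outer update, and all function names are assumptions for illustration. EDiT's actual pseudo gradient penalty is specified in the paper and may differ.

```python
import math
import statistics
from typing import List, Sequence


def l2_norm(v: Sequence[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def penalized_pseudo_gradient(
    global_params: Sequence[float],
    local_params: Sequence[Sequence[float]],
    outlier_ratio: float = 3.0,  # hypothetical rule: drop norms above ratio * median (ratio >= 1)
    max_norm: float = 1.0,       # hypothetical clipping threshold
) -> List[float]:
    """Combine per-worker pseudo gradients with a simplified, assumed penalty."""
    # Pseudo gradient of worker k: start-of-round params minus its params after local steps.
    grads = [[g - p for g, p in zip(global_params, lp)] for lp in local_params]
    norms = [l2_norm(g) for g in grads]

    # 1) Anomaly elimination (assumed rule): ignore workers far above the median norm.
    median = statistics.median(norms)
    kept = [k for k, n in enumerate(norms) if n <= outlier_ratio * median]

    # 2) Weighted averaging (assumed rule): smaller norms receive larger weights.
    inv = [1.0 / (norms[k] + 1e-12) for k in kept]
    inv_sum = sum(inv)
    weights = [w / inv_sum for w in inv]
    avg = [
        sum(w * grads[k][j] for w, k in zip(weights, kept))
        for j in range(len(global_params))
    ]

    # 3) Clip the combined pseudo gradient to max_norm.
    total = l2_norm(avg)
    if total > max_norm:
        avg = [x * max_norm / total for x in avg]
    return avg


def outer_step(global_params: Sequence[float], pseudo_grad: Sequence[float],
               outer_lr: float = 1.0) -> List[float]:
    """Plain SGD outer update (an assumption; momentum-based outer optimizers also exist)."""
    return [p - outer_lr * g for p, g in zip(global_params, pseudo_grad)]


if __name__ == "__main__":
    theta = [0.0, 0.0]
    # Three workers drift slightly; the fourth hits a loss spike and drifts far away.
    local_models = [[-0.1, 0.0], [-0.1, 0.0], [0.0, -0.1], [5.0, 5.0]]
    g = penalized_pseudo_gradient(theta, local_models)
    print(outer_step(theta, g))  # roughly [-0.0667, -0.0333]; the spiking worker is ignored
```

With `outer_lr=1.0` and equal weights, the outer step reduces to averaging the retained local models, which is the classic Local SGD update. The filtering step keeps a single spiking worker from dragging the shared model away.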

Paper Structure

This paper contains 22 sections, 1 theorem, 19 equations, 10 figures, 7 tables, 2 algorithms.

Key Result

Theorem 1

Suppose that the following assumptions are satisfied: … Then Algorithm alg:illustration yields … where the meanings of $n$, $\phi$, and $\epsilon$ are listed in Table tab:notation of Appendix appendix:convergence.
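The TL;DR quotes the convergence guarantee as $O\left(\dfrac{\log(T)}{\sqrt{T}}\right)$. Ignoring constants and taking $\log$ as the natural logarithm (the base only changes a constant factor), a short calculation confirms that this quantity eventually decreases and vanishes as the number of iterations $T$ grows:

$$\frac{d}{dT}\,\frac{\log T}{\sqrt{T}} = \frac{1}{T^{3/2}} - \frac{\log T}{2\,T^{3/2}} = \frac{2 - \log T}{2\,T^{3/2}} < 0 \quad \text{for } T > e^{2},$$

$$\lim_{T \to \infty} \frac{\log T}{\sqrt{T}} = \lim_{T \to \infty} \frac{1/T}{1/(2\sqrt{T})} = \lim_{T \to \infty} \frac{2}{\sqrt{T}} = 0 \quad \text{(L'Hôpital's rule)}.$$

The rate is therefore slower than the familiar $O(1/\sqrt{T})$ rate only by a $\log T$ factor.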

Figures (10)

  • Figure 1: The schematic illustration of our proposed EDiT method with $4$ workers and a $2 \times 2$ device mesh as an example. The left part shows the communication groups and parameter sharding, and the right part presents the detailed computation and communication flows within worker B.
  • Figure 2: Illustration of model synchronization and our proposed pseudo gradient penalty method, depicted with an example of four workers in a model sync group.
  • Figure 3: A comparison of the synchronization scheme of EDiT and A-EDiT.
  • Figure 4: The loss and PPL curves of different methods on (a) & (b) the FineWeb-Edu dataset and (c) & (d) the in-house dataset. The final values are marked, with the best ones in bold. Each reported value is the average of the last 10 values, which reduces the effect of noise. PLS is short for Post Local SGD.
  • Figure 5: The TFLOPS of different methods under different training scenarios.
  • ...and 5 more figures
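Figure 3 contrasts the two synchronization schemes. The toy simulation assumes two schemes:

- EDiT synchronizes after a fixed number of local steps.
- A-EDiT synchronizes after a fixed wall-clock interval.

The per-step times and function names are hypothetical. The simulation shows that a single straggler leaves faster workers idle under the first scheme, whereas under the second they simply complete more local steps.

```python
from typing import List, Tuple


def fixed_step_round(step_times: List[float], local_steps: int) -> Tuple[float, List[int], List[float]]:
    """Fixed number of local steps per round (EDiT-style, as assumed here)."""
    busy = [t * local_steps for t in step_times]
    round_time = max(busy)                 # everyone waits for the slowest worker
    idle = [round_time - b for b in busy]  # time faster workers spend waiting
    return round_time, [local_steps] * len(step_times), idle


def fixed_time_round(step_times: List[float], interval: float) -> Tuple[float, List[int]]:
    """Fixed wall-clock interval per round (A-EDiT-style, as assumed here)."""
    # Each worker completes as many whole steps as fit in the interval.
    steps = [int(interval // t) for t in step_times]
    return interval, steps


if __name__ == "__main__":
    times = [1.0, 1.0, 1.0, 2.0]  # hypothetical seconds per step; worker 3 is a straggler
    print(fixed_step_round(times, 4))    # (8.0, [4, 4, 4, 4], [4.0, 4.0, 4.0, 0.0])
    print(fixed_time_round(times, 8.0))  # (8.0, [8, 8, 8, 4])
```

In both cases a round lasts 8 seconds. The fixed-step scheme wastes 4 seconds on each fast worker, while the fixed-time scheme turns that slack into extra local steps. In exchange, workers contribute pseudo gradients accumulated over different numbers of steps.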

Theorems & Definitions (2)

  • Theorem 1
  • proof