LoongTrain: Efficient Training of Long-Sequence LLMs with Head-Context Parallelism
Diandian Gu, Peng Sun, Qinghao Hu, Ting Huang, Xun Chen, Yingtong Xiong, Guoteng Wang, Qiaoling Chen, Shangchun Zhao, Jiarui Fang, Yonggang Wen, Tianwei Zhang, Xin Jin, Xuanzhe Liu
TL;DR
LoongTrain targets the training of LLMs on very long sequences. Its core is 2D-Attention, which combines head parallelism (HP) and context parallelism (CP) to overcome the scaling limit of HP while reducing the communication overhead of CP. Further contributions include KV replication for grouped-query attention (GQA), Double-Ring-Attention for more efficient use of the interconnect, and selective gradient checkpointing combined with Hybrid ZeRO to reduce memory use while maintaining throughput. In the reported experiments, LoongTrain outperforms the state-of-the-art baselines DeepSpeed-Ulysses and Megatron Context Parallelism in end-to-end training speed and scalability, improving Model FLOPs Utilization (MFU) by up to 2.88x. The system is implemented in an internal framework.
Abstract
Efficiently training LLMs with long sequences is important but challenging because of the massive computation and memory requirements. Sequence parallelism has been proposed to tackle these problems, but existing methods suffer from scalability or efficiency issues. We propose LoongTrain, a novel system to efficiently train LLMs with long sequences at scale. The core of LoongTrain is the 2D-Attention mechanism, which combines head-parallel and context-parallel techniques to break the scalability constraints while maintaining efficiency. We introduce Double-Ring-Attention and analyze the performance of device placement strategies to further speed up training. We implement LoongTrain with the Hybrid ZeRO and Selective Checkpoint++ techniques. Experimental results show that LoongTrain outperforms state-of-the-art baselines, i.e., DeepSpeed-Ulysses and Megatron Context Parallelism, in both end-to-end training speed and scalability, and improves Model FLOPs Utilization (MFU) by up to 2.88x.
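
To make the idea of combining head-parallel and context-parallel groups concrete, the following is a minimal sketch of how the ranks of one sequence-parallel group could be arranged on a 2D grid. It is an illustration under stated assumptions, not LoongTrain's implementation: the function name `build_2d_groups`, its parameters, and the `head_first` flag are hypothetical, and the two rank orderings merely stand in for the kind of device placement choice the abstract refers to.

```python
from typing import List, Tuple


def build_2d_groups(
    sp_size: int, hp_size: int, head_first: bool = True
) -> Tuple[List[List[int]], List[List[int]]]:
    """Arrange sp_size ranks on an (hp_size x cp_size) grid.

    Hypothetical helper: returns (hp_groups, cp_groups), where each
    rank appears in exactly one head-parallel group and exactly one
    context-parallel group.
    """
    if hp_size <= 0 or sp_size % hp_size != 0:
        raise ValueError("hp_size must be a positive divisor of sp_size")
    cp_size = sp_size // hp_size

    if head_first:
        # Consecutive ranks share a head-parallel group.
        hp_groups = [[c * hp_size + h for h in range(hp_size)]
                     for c in range(cp_size)]
        cp_groups = [[c * hp_size + h for c in range(cp_size)]
                     for h in range(hp_size)]
    else:
        # Consecutive ranks share a context-parallel group.
        cp_groups = [[h * cp_size + c for c in range(cp_size)]
                     for h in range(hp_size)]
        hp_groups = [[h * cp_size + c for h in range(hp_size)]
                     for c in range(cp_size)]
    return hp_groups, cp_groups


if __name__ == "__main__":
    for flag in (True, False):
        hp, cp = build_2d_groups(sp_size=8, hp_size=2, head_first=flag)
        print(f"head_first={flag}: hp_groups={hp}, cp_groups={cp}")
```

In this sketch the total sequence-parallel degree is the product of the two group sizes, so it can grow through the context-parallel dimension even when the head-parallel size is capped by the number of attention heads. Which ordering places which kind of group on neighbouring ranks is the sort of trade-off a placement analysis would examine.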
