Optimizing Large Model Training through Overlapped Activation Recomputation

Ping Chen, Wenjie Zhang, Shuibing He, Weijian Chen, Siling Yang, Kexin Huang, Yanlong Yin, Xuan Zhan, Yingjie Gu, Zhuwei Peng, Yi Zheng, Zhefeng Wang, Gang Chen

TL;DR

The paper targets the overhead of activation recomputation in large-model training, which existing approaches execute on demand on the critical path. It introduces Lynx, a framework that overlaps activation recomputation with communication in training pipelines and pairs it with recomputation-aware model partitioning that balances per-stage execution time. Two scheduling strategies are developed: Lynx-OPT, a MILP formulation that serves as an optimal upper bound, and Lynx-HEU, a heuristic that applies one policy to all identical structures in the model to reach near-optimal throughput with minimal search time. Evaluation on GPT models with 1.3B–23B parameters across NVLink, PCIe, and Ascend clusters shows speedups of up to 1.37x over existing recomputation approaches.

Abstract

Large model training often uses recomputation to alleviate memory pressure and pipelines to exploit the parallelism of data, tensors, and devices. However, existing recomputation approaches may incur high overhead when training real-world models, as they are executed on demand in the critical training path. In this paper, we present Lynx, a new recomputation framework to reduce overhead by overlapping recomputation with communication in training pipelines. To reduce the large search space for recomputation strategies, we propose a heuristic-based recomputation scheduling algorithm, which is based on the observation that there are identical structures in large DNN models so that we can apply the same scheduling policy to all such structures. Additionally, we propose a recomputation-aware model partitioning method to balance each stage's execution time for improved training throughput. Our comprehensive evaluation using GPT models with 1.3B-23B parameters shows that Lynx outperforms existing recomputation approaches by up to 1.37x.
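
The benefit of overlapping can be illustrated with a toy cost model. The sketch below is not Lynx's scheduler; the function name `step_time`, its parameters, and the assumption that recomputation can hide behind communication up to the full communication duration are all hypothetical simplifications for illustration.

```python
def step_time(compute_ms, comm_ms, recompute_ms, overlap):
    """Toy time for one unit of work (hypothetical model, not from the paper).

    On-demand recomputation is serialized with everything else.
    Overlapped recomputation runs during communication, so only the
    part that exceeds the communication time stays on the critical path.
    """
    if overlap:
        exposed = max(0, recompute_ms - comm_ms)  # recomputation not hidden
    else:
        exposed = recompute_ms                    # all recomputation exposed
    return compute_ms + comm_ms + exposed


if __name__ == "__main__":
    # Illustrative numbers only: 10 ms compute, 4 ms communication.
    for recompute in (3, 6):
        on_demand = step_time(10, 4, recompute, overlap=False)
        overlapped = step_time(10, 4, recompute, overlap=True)
        print(recompute, on_demand, overlapped)
```

Under these assumptions, recomputation that fits inside the communication window costs nothing extra, while longer recomputation is only partly hidden. This is why the amount of communication available for overlap matters.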

Paper Structure

This paper contains 24 sections, 9 equations, 17 figures, and 3 tables.

Figures (17)

  • Figure 1: The training workflow of tensor parallelism. The shaded rectangle indicates the splitting of the tensor onto another GPU for parallel training. $g$ denotes the all-reduce operation in the forward and backward passes.
  • Figure 2: The training workflow of pipeline parallelism (one-forward-one-backward). Each minibatch consists of 5 micro-batches. The example illustrates that ideal computation-balanced model partitioning achieves the best training performance.
  • Figure 3: (a) The ratio of TP communication during training. The $x$-axis represents the number of GPUs in a TP group. (b) Imbalanced stage (GPU) memory consumption (TP=2, PP=4). The memory usage is normalized to that of stage 0.
  • Figure 4: An example of forward, backward, and recomputation processes. $T_1$ is evicted at time $t_1$ and can be recomputed anytime between $t_1$ and $t_2$.
  • Figure 5: Overview of Lynx.
  • ...and 12 more figures
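
Figure 2's caption notes that computation-balanced partitioning gives the best pipeline performance. As a minimal sketch of that idea, the code below splits contiguous layers into stages so that the slowest stage is as fast as possible, using the classic linear-partition dynamic program. It is not the paper's partitioner: the function `balanced_partition` and the per-layer costs are hypothetical, and each cost is assumed to already include forward, backward, and any exposed recomputation time for that layer.

```python
from itertools import accumulate


def balanced_partition(layer_costs, num_stages):
    """Split contiguous layers into num_stages non-empty stages,
    minimizing the cost of the slowest stage (the pipeline bottleneck).

    Returns (bottleneck_cost, [(start, end), ...]) with end exclusive.
    Assumes len(layer_costs) >= num_stages.
    """
    n = len(layer_costs)
    prefix = [0] + list(accumulate(layer_costs))  # prefix[i] = sum of first i layers
    inf = float("inf")
    # best[k][i]: smallest bottleneck when the first i layers form k stages
    best = [[inf] * (n + 1) for _ in range(num_stages + 1)]
    cut = [[0] * (n + 1) for _ in range(num_stages + 1)]
    best[0][0] = 0
    for k in range(1, num_stages + 1):
        for i in range(k, n + 1):
            for j in range(k - 1, i):  # last stage covers layers j..i-1
                cand = max(best[k - 1][j], prefix[i] - prefix[j])
                if cand < best[k][i]:
                    best[k][i] = cand
                    cut[k][i] = j
    # Walk the recorded cuts backwards to recover stage boundaries.
    bounds, i = [], n
    for k in range(num_stages, 0, -1):
        j = cut[k][i]
        bounds.append((j, i))
        i = j
    bounds.reverse()
    return best[num_stages][n], bounds


if __name__ == "__main__":
    # Hypothetical costs: the last two layers are twice as expensive,
    # e.g. because more of their recomputation is exposed.
    costs = [3, 3, 3, 3, 6, 6]
    print(balanced_partition(costs, 2))
```

With these illustrative costs, an even split by layer count (three layers per stage) would leave one stage at 9 and the other at 15, whereas the cost-aware split assigns four cheap layers to the first stage and two expensive layers to the second.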