Scaling Deep Learning Training with MPMD Pipeline Parallelism

Anxhelo Xhebraj, Sean Lee, Hanfeng Chen, Vinod Grover

TL;DR

JaxPP introduces a novel MPMD pipeline parallelism system built on top of JAX/XLA to overcome the limitations of traditional SPMD approaches for large-scale model training. By enabling user-defined pipeline schedules through a gradient-accumulation loop, stage marking, and placement inference, it achieves asynchronous, distributed execution with automatic communication handling via a single-controller runtime. The approach yields substantial throughput and utilization gains over SPMD pipeline parallelism and competitive performance relative to NeMo, while reducing the coding burden for complex parallel configurations. This work enables more flexible, scalable training of very large models across clusters with varying interconnect bandwidths, with practical impact on research and deployment of large-scale deep learning systems.

Abstract

We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that allows implementing user-defined pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, corresponding to pipeline stages, over a cluster of nodes and automatically infers the communication among them. We implement an MPMD runtime for asynchronous execution of SPMD tasks. The pipeline parallelism implementation of JaxPP improves hardware utilization by up to $1.11\times$ with respect to the best-performing SPMD configuration.

Paper Structure

This paper contains 26 sections, 2 equations, 12 figures, 1 table.

Figures (12)

  • Figure 1: Model implementation with named axes
  • Figure 2: Partitioning specification
  • Figure 3: Different parallelism instantiations depending on the mesh shape
  • Figure 5: Comparison between GPipe and 1F1B. In GPipe, at any time, all pipeline-parallel groups perform the same computation. Bubbles are implemented as redundant discarded computation (gray Z blocks). In 1F1B, all groups perform different computations.
  • Figure 6: System Overview. The left box shows the code in the driver process describing the computation and annotating pipeline stage boundaries. Auto-differentiation produces additional stages corresponding to the "backward" computations for the gradients. The user specifies a mapping of stages to SPMD actors and a schedule for the loop. Each call to the function schedules tasks
  • ...and 7 more figures
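
The difference between the GPipe and 1F1B schedules named in the Figure 5 caption can be sketched as the order in which a single pipeline stage processes forward (F) and backward (B) microbatch tasks. The sketch below is plain Python and is not the JaxPP API: the function names `gpipe`, `one_f_one_b`, and `peak_in_flight` are hypothetical, and the in-order backward pass for GPipe is an assumption made for illustration.

```python
def gpipe(stage, num_stages, num_microbatches):
    """GPipe-style order for one stage: all forwards, then all backwards.

    The order is the same for every stage (assumption: backwards run in
    microbatch order), so `stage` and `num_stages` are unused here.
    """
    forwards = [("F", i) for i in range(num_microbatches)]
    backwards = [("B", i) for i in range(num_microbatches)]
    return forwards + backwards


def one_f_one_b(stage, num_stages, num_microbatches):
    """1F1B-style order for one stage (0-indexed).

    Warm-up: earlier stages run a few forwards first.
    Steady state: alternate one forward with one backward.
    Cool-down: drain the remaining backwards.
    """
    warmup = min(num_stages - 1 - stage, num_microbatches)
    ops = [("F", i) for i in range(warmup)]
    next_f, next_b = warmup, 0
    while next_f < num_microbatches:
        ops.append(("F", next_f))
        next_f += 1
        ops.append(("B", next_b))
        next_b += 1
    while next_b < num_microbatches:
        ops.append(("B", next_b))
        next_b += 1
    return ops


def peak_in_flight(ops):
    """Largest number of microbatches with a forward done but no backward yet."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


if __name__ == "__main__":
    stages, microbatches = 4, 8
    for s in range(stages):
        order = one_f_one_b(s, stages, microbatches)
        print(s, " ".join(f"{k}{i}" for k, i in order))
    print("GPipe peak:", peak_in_flight(gpipe(0, stages, microbatches)))
    print("1F1B peak:", peak_in_flight(one_f_one_b(0, stages, microbatches)))
```

In this sketch each 1F1B stage follows a different order, whereas every GPipe stage follows the same one. The number of microbatches whose activations a stage must hold at once is bounded by the pipeline depth under 1F1B, rather than by the number of microbatches as under GPipe.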