Scaling Deep Learning Training with MPMD Pipeline Parallelism
Anxhelo Xhebraj, Sean Lee, Hanfeng Chen, Vinod Grover
TL;DR
JaxPP is an MPMD pipeline parallelism system built on JAX/XLA that addresses the limitations of traditional SPMD approaches to large-scale model training. Users define pipeline schedules by writing a gradient-accumulation loop and marking stages; JaxPP infers task placement, handles the required communication automatically, and executes tasks asynchronously across devices under a single-controller runtime. The approach improves throughput and hardware utilization over SPMD pipeline parallelism and performs competitively with NeMo, while reducing the code needed for complex parallel configurations. This makes training very large models more flexible and scalable on clusters with varying interconnect bandwidths.
Abstract
We present JaxPP, a system for efficiently scaling the training of large deep learning models with flexible pipeline parallelism. We introduce a seamless programming model that allows implementing user-defined pipeline schedules for gradient accumulation. JaxPP automatically distributes tasks, corresponding to pipeline stages, over a cluster of nodes and automatically infers the communication among them. We implement an MPMD runtime for asynchronous execution of SPMD tasks. The pipeline parallelism implementation of JaxPP improves hardware utilization by up to $1.11\times$ with respect to the best-performing SPMD configuration.
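To make the notion of a pipeline schedule concrete, the following minimal sketch represents a schedule as an ordered list of tasks per stage and builds a GPipe-style ordering (all forward microbatches, then all backward microbatches). It is an illustration only and does not use or reproduce the JaxPP API: the names `Task`, `gpipe_schedule`, and `bubble_fraction` are hypothetical, and the bubble estimate assumes every stage takes the same time per microbatch.

```python
from typing import List, Tuple

# Hypothetical task encoding: (kind, stage, microbatch), kind is "fwd" or "bwd".
Task = Tuple[str, int, int]


def gpipe_schedule(num_stages: int, num_microbatches: int) -> List[List[Task]]:
    """Return, for each stage, the order in which it runs its tasks.

    GPipe-style ordering: every stage runs all forward microbatches first,
    then all backward microbatches (here in reverse microbatch order,
    which is one common choice).
    """
    schedule: List[List[Task]] = []
    for stage in range(num_stages):
        fwd = [("fwd", stage, mb) for mb in range(num_microbatches)]
        bwd = [("bwd", stage, mb) for mb in reversed(range(num_microbatches))]
        schedule.append(fwd + bwd)
    return schedule


def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction of a GPipe-style schedule, assuming uniform stage times."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


if __name__ == "__main__":
    for stage_tasks in gpipe_schedule(num_stages=2, num_microbatches=3):
        print(stage_tasks)
    print(f"bubble fraction (4 stages, 8 microbatches): {bubble_fraction(4, 8):.3f}")
```

Under these assumptions, raising the number of microbatches in the gradient-accumulation loop shrinks the idle fraction, which is one reason the choice of schedule affects hardware utilization.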
