DeepCompile: A Compiler-Driven Approach to Optimizing Distributed Deep Learning Training

Masahiro Tanaka, Du Li, Umesh Chand, Ali Zafar, Haiying Shen, Olatunji Ruwase

TL;DR

DeepCompile tackles the inefficiency of fully sharded distributed training by transforming user models into computation graphs and applying profiling-guided optimization passes that coordinate prefetching, unsharding, and adaptive offloading under dynamic memory usage. By operating at the graph level and re-profiling between passes, DeepCompile adapts to runtime memory patterns and data dependencies to maximize communication–computation overlap. It implements a fully sharded baseline and three optimization passes, achieving up to 1.28× and 1.54× improvements over ZeRO-3 and FSDP, respectively, and up to 7.01× higher throughput with limited GPUs via offloading, on Llama 3 70B and Mixtral 8×7B MoE. Compilation is a one-time cost, and its result can be cached. The results demonstrate a practical, compiler-driven path to more efficient distributed training, offering scalable memory and throughput benefits without requiring model code changes.

Abstract

The increasing scale of deep learning models has led to the development of various parallelization strategies for distributed training across accelerators. For example, fully sharded approaches like DeepSpeed ZeRO-3 and FSDP partition the parameters of each layer across multiple GPUs and gather them through communication when needed. These methods rely on optimizations such as prefetching, which initiates communication early to overlap it with computation and reduce communication overhead, and unsharding, which retains as many parameters in their unsharded form as possible to reduce communication volume. Although the timing of prefetching should be adjusted in response to dynamic memory usage during execution, these systems lack the flexibility to control it, which limits the benefits of prefetching. Moreover, they cannot anticipate how memory usage will change after prefetching is applied, making it difficult to combine it effectively with other optimizations such as unsharding. We present DeepCompile, which compiles user-defined models into computation graphs and applies a sequence of profiling-guided optimization passes for distributed training. Taking dynamic memory usage into account, these passes flexibly insert, reorder, or remove operations to improve communication-computation overlap, reduce memory pressure, and coordinate multiple optimizations in a unified manner. To evaluate the effectiveness of this design, we implemented a fully sharded approach like ZeRO-3 and FSDP on top of DeepCompile, along with three optimizations: proactive prefetching, selective unsharding, and adaptive offloading. We evaluate DeepCompile on the training of Llama 3 70B and Mixtral 8x7B MoE models. DeepCompile achieves up to 1.28x and 1.54x performance improvements over ZeRO-3 and FSDP baselines, respectively, and up to a 7.01x throughput increase with limited GPU resources, using offloading.
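
To make the idea of memory-aware prefetch scheduling concrete, the following is a minimal sketch, not DeepCompile's actual pass. It assumes a hypothetical `Op` record carrying a profiled memory figure and the size of the unsharded parameter the op needs, and a hypothetical `schedule_prefetch` helper that moves each all-gather as early as possible while profiled memory plus already-prefetched buffers stays within a limit. All names, units, and the greedy strategy are illustrative assumptions.

```python
from dataclasses import dataclass


@dataclass
class Op:
    """One node of a linearized computation graph (hypothetical structure)."""
    name: str
    activation_mem: int  # profiled memory in use while this op runs (arbitrary units)
    param_size: int      # size of the unsharded parameter this op needs (0 if none)


def schedule_prefetch(ops, mem_limit):
    """Greedy sketch: pick the earliest step at which each all-gather can be issued.

    Returns a dict mapping op name -> step index where its parameter is gathered.
    A gathered buffer is assumed to live from its issue step through its use step.
    """
    # prefetched[i] = memory held at step i by buffers that were gathered early
    prefetched = [0] * len(ops)
    plan = {}
    for i, op in enumerate(ops):
        if op.param_size == 0:
            continue
        start = i
        # Walk backward while the extra buffer still fits under the limit.
        while start > 0:
            prev = start - 1
            needed = ops[prev].activation_mem + prefetched[prev] + op.param_size
            if needed > mem_limit:
                break
            start = prev
        # Reserve the buffer from the issue step through the step that uses it.
        for step in range(start, i + 1):
            prefetched[step] += op.param_size
        plan[op.name] = start
    return plan


if __name__ == "__main__":
    graph = [
        Op("embed", activation_mem=2, param_size=0),
        Op("layer1", activation_mem=3, param_size=4),
        Op("layer2", activation_mem=5, param_size=4),
        Op("loss", activation_mem=9, param_size=0),
    ]
    print(schedule_prefetch(graph, mem_limit=10))
    print(schedule_prefetch(graph, mem_limit=12))
```

In this toy setup, a tighter memory limit forces the second layer's all-gather to wait until the step that uses it, while a looser limit lets both gathers be issued at the first step. This mirrors, in a very simplified form, why prefetch timing depends on the profiled memory trend.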

Paper Structure

This paper contains 24 sections, 10 figures, 2 tables, 2 algorithms.

Figures (10)

  • Figure 1: Memory usage trends and scheduling opportunities for prefetching and offloading. Profile of several final layers with a sequence length of 4096 and a vocabulary size of 128k. Significant memory spikes are observed in the log-softmax and negative log-likelihood loss layers.
  • Figure 2: Workflow of compilation and optimization with DeepCompile
  • Figure 3: Optimization and profiling loop in DeepCompile
  • Figure 4: Graph modification
  • Figure 5: Parameter sharding and prefetching
  • ...and 5 more figures