SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile

Ruisi Zhang, Tianyu Liu, Will Feng, Andrew Gu, Sanket Purandare, Wanchao Liang, Francisco Massa

TL;DR

This paper presents SimpleFSDP, a PyTorch-native, compiler-based Fully Sharded Data Parallel (FSDP) framework. Its simple implementation eases maintenance and composability, it allows tracing of the full computation-communication graph, and it improves performance through compiler backend optimizations.

Abstract

Distributed training of large models consumes enormous computation resources and requires substantial engineering effort to compose various training techniques. This paper presents SimpleFSDP, a PyTorch-native, compiler-based Fully Sharded Data Parallel (FSDP) framework. Its simple implementation eases maintenance and composability, it allows tracing of the full computation-communication graph, and it improves performance through compiler backend optimizations. SimpleFSDP's novelty lies in its torch.compile-friendly implementation of collective communications using existing PyTorch primitives, namely parametrizations, selective activation checkpointing, and DTensor. It also introduces the first-of-its-kind bucketing and reordering of intermediate representation (IR) nodes in the TorchInductor backend for effective computation-communication overlap. As a result, users can apply these optimizations through automatic or manual wrapping of model components to minimize exposed communication. Extensive evaluations of SimpleFSDP on Llama 3 models (including the ultra-large 405B) using TorchTitan demonstrate up to 28.54% memory reduction and 68.67% throughput improvement compared to FSDP2, the most widely adopted eager-mode FSDP framework, when composed with other distributed training techniques.
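The core FSDP communication pattern that SimpleFSDP expresses through parametrizations and DTensor can be illustrated without any GPUs. The following is a pure-Python simulation, not SimpleFSDP's code: each rank persistently stores one contiguous shard of a parameter (analogous to a DTensor `Shard(0)` placement), all-gathers the full parameter just before using it in the forward pass, and reduce-scatters gradients in the backward pass (the AG and RS nodes in Figure 2). The helpers `shard`, `all_gather`, `reduce_scatter`, and `forward_backward`, the two-rank setup, and the toy linear loss are hypothetical choices made for illustration.

```python
from typing import List, Tuple

WORLD_SIZE = 2  # number of simulated ranks (illustrative assumption)


def shard(full: List[float], world_size: int) -> List[List[float]]:
    """Split a flat tensor into equal contiguous chunks, one per rank (like Shard(0))."""
    assert len(full) % world_size == 0, "sketch assumes an evenly divisible size"
    chunk = len(full) // world_size
    return [full[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards: List[List[float]]) -> List[float]:
    """Every rank receives the concatenation of all shards: the unsharded parameter."""
    return [v for s in shards for v in s]


def reduce_scatter(grads: List[List[float]], world_size: int) -> List[List[float]]:
    """Sum full-size gradients across ranks, then keep only each rank's own chunk."""
    summed = [sum(col) for col in zip(*grads)]
    return shard(summed, world_size)


# Each rank persistently stores only its shard of the parameter.
full_weight = [1.0, 2.0, 3.0, 4.0]
weight_shards = shard(full_weight, WORLD_SIZE)  # rank 0: [1, 2], rank 1: [3, 4]
local_inputs = [[1.0, 0.0, 2.0, 0.0], [0.0, 3.0, 0.0, 1.0]]  # different data per rank


def forward_backward(rank: int) -> Tuple[float, List[float]]:
    w = all_gather(weight_shards)  # forward: unshard just before use
    x = local_inputs[rank]
    loss = sum(wi * xi for wi, xi in zip(w, x))  # toy linear loss
    grad = list(x)  # d(loss)/dw = x for this loss
    return loss, grad


results = [forward_backward(r) for r in range(WORLD_SIZE)]
losses = [loss for loss, _ in results]
grad_shards = reduce_scatter([g for _, g in results], WORLD_SIZE)  # backward: RS

# Same result as plain data parallelism: all-reduce the full gradient, then shard it.
all_reduced = [sum(col) for col in zip(*(g for _, g in results))]
assert grad_shards == shard(all_reduced, WORLD_SIZE)
print(losses, grad_shards)  # [7.0, 10.0] [[1.0, 3.0], [2.0, 1.0]]
```

The assertion checks the property that makes sharding correct: each rank ends up with exactly its slice of the gradient that plain data parallelism would obtain by all-reducing full gradients, while persistently storing only a shard of the parameter.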

Paper Structure

This paper contains 23 sections, 5 figures, 6 tables, and 1 algorithm.

Figures (5)

  • Figure 1: SimpleFSDP's frontend implementation.
  • Figure 2: An overview of SimpleFSDP's optimizations and model wrapping. The left side shows the forward pass and the right side the backward pass. The yellow box shows the IR node scheduling in TorchInductor and the corresponding execution order on the GPU; a blue box groups IR nodes that come from the same module. In Manual Wrapping, the all-gathers (AG) and reduce-scatters (RS) from the same module are bucketed into new communication IR nodes, and the bucketed communication and computation are then reordered so that communication is prefetched during the current computation. In Auto Wrapping, all-gathers and reduce-scatters are bucketed as long as the bucketed communication can still be overlapped by the current computation and does not exceed the memory limit; the bucketed communication and computation are then reordered to hide communication.
  • Figure 3: SimpleFSDP performance on Llama 3 8B, 70B, and 405B models when trained on different numbers of H100 GPUs. Peak memory is reported in GiB and throughput in TPS (tokens/s).
  • Figure 4: Auto-Wrapping performance when training Llama 3.1 8B and 70B models on different numbers of H100 GPUs.
  • Figure 5: Loss curve of FSDP2 and SimpleFSDP on Llama 3.1 8B.
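The Auto Wrapping policy described in the Figure 2 caption can be sketched as a greedy pass over the model's components. In this minimal sketch, which is not TorchInductor's actual algorithm or API, each block carries hypothetical cost-model estimates (`comm_ms`, `compute_ms`, `mem_mb`). A block joins the bucket under construction only while that bucket's all-gather can still be hidden behind the previous bucket's compute and its memory stays under a limit. `prefetch_schedule` then produces the reordered execution, issuing bucket k+1's all-gather before bucket k's compute. `Block`, `auto_bucket`, and `prefetch_schedule` are hypothetical names, only forward all-gathers are modeled, and the memory limit is applied per bucket, which ignores that a prefetched bucket coexists with the one currently computing.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Block:
    """A wrappable model component with hypothetical cost-model estimates."""
    name: str
    comm_ms: float     # estimated all-gather time for its parameters
    compute_ms: float  # estimated forward compute time
    mem_mb: float      # memory held by its unsharded parameters


def auto_bucket(blocks: List[Block], mem_limit_mb: float) -> List[List[str]]:
    """Greedily group consecutive blocks into all-gather buckets.

    Bucket k is prefetched while bucket k-1 computes, so a block joins the
    bucket under construction only if the bucket's total all-gather time stays
    within bucket k-1's compute time and its memory stays within the limit.
    """
    if not blocks:
        return []
    buckets: List[List[Block]] = [[blocks[0]]]  # the first all-gather is always exposed
    current: List[Block] = []
    for b in blocks[1:]:
        budget = sum(x.compute_ms for x in buckets[-1])  # compute that hides this bucket
        comm = sum(x.comm_ms for x in current) + b.comm_ms
        mem = sum(x.mem_mb for x in current) + b.mem_mb
        if current and (comm > budget or mem > mem_limit_mb):
            buckets.append(current)  # close the bucket and start a new one
            current = [b]
        else:
            # Add to the bucket; a lone block is kept even if it exceeds the budget.
            current.append(b)
    if current:
        buckets.append(current)
    return [[x.name for x in bucket] for bucket in buckets]


def prefetch_schedule(num_buckets: int) -> List[str]:
    """Reorder so bucket k+1's all-gather is issued before bucket k's compute."""
    order = ["AG(b0)"] if num_buckets else []
    for k in range(num_buckets):
        if k + 1 < num_buckets:
            order.append(f"AG(b{k + 1})")
        order.append(f"compute(b{k})")
    return order


blocks = (
    [Block("embed", 1.0, 2.0, 10.0)]
    + [Block(f"layer{i}", 2.0, 4.0, 20.0) for i in range(4)]
    + [Block("output", 1.0, 2.0, 10.0)]
)
buckets = auto_bucket(blocks, mem_limit_mb=50.0)
print(buckets)  # [['embed'], ['layer0'], ['layer1', 'layer2'], ['layer3', 'output']]
print(prefetch_schedule(len(buckets)))
```

Manual Wrapping corresponds to skipping `auto_bucket` and using one bucket per user-wrapped module. The backward pass would bucket and reorder reduce-scatters analogously, in reverse order.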