PartIR: Composing SPMD Partitioning Strategies for Machine Learning

Sami Alabed, Daniel Belov, Bart Chrzaszcz, Juliana Franco, Dominik Grewe, Dougal Maclaurin, James Molloy, Tom Natan, Tamara Norman, Xiaoyue Pan, Adam Paszke, Norman A. Rink, Michael Schaarschmidt, Timur Sitdikov, Agnieszka Swietlik, Dimitrios Vytiniotis, Joel Wee

TL;DR

PartIR tackles the growing complexity of training large neural networks by decoupling partitioning strategies from model code and providing a tactic-driven, MLIR-based compiler stack. It supports composing data-, model-, and optimizer-sharding strategies through manual and automatic tactics, using propagation-based rewriting that avoids sharding conflicts, and it validates performance analytically with a simulator. The design is backed by a formal correctness argument for the translation from PartIR:Core to PartIR:HLO. In the evaluation, PartIR reaches model FLOPs utilization (MFU) close to the state of the art with memory usage comparable to GSPMD, and it benefits from incremental propagation and from automatic tactics that can be scheduled alongside manual ones. Together these properties give a practical partitioning workflow that reduces manual tuning and makes performance predictable for large models.

Abstract

Training of modern large neural networks (NN) requires a combination of parallelization strategies encompassing data, model, or optimizer sharding. When strategies increase in complexity, it becomes necessary for partitioning tools to be 1) expressive, allowing the composition of simpler strategies, and 2) predictable to estimate performance analytically. We present PartIR, our design for a NN partitioning system. PartIR is focused on an incremental approach to rewriting and is hardware-and-runtime agnostic. We present a simple but powerful API for composing sharding strategies and a simulator to validate them. The process is driven by high-level programmer-issued partitioning tactics, which can be both manual and automatic. Importantly, the tactics are specified separately from the model code, making them easy to change. We evaluate PartIR on several different models to demonstrate its predictability, expressibility, and ability to reach peak performance.
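
The idea of keeping tactics separate from model code can be illustrated with a small, self-contained sketch. This is not PartIR's API: `ShardTactic`, `apply_tactics`, the tensor names, and the mesh axis names are all hypothetical, and the sketch only computes device-local shapes. It does not model the propagation step or any rewriting of the program.

```python
# Hypothetical sketch: none of these names come from PartIR's actual API.
from dataclasses import dataclass

# "Model code": tensor shapes only, with no knowledge of devices or sharding.
MODEL_TENSORS = {
    "x": (256, 1024),   # activations: (batch, features)
    "w": (1024, 4096),  # parameters: (features, hidden)
}


@dataclass(frozen=True)
class ShardTactic:
    """Shard dimension `dim` of tensor `name` along mesh axis `axis`."""
    name: str
    dim: int
    axis: str


def apply_tactics(tensors, mesh, tactics):
    """Apply tactics in order and return the device-local shape of each tensor.

    Raises ValueError if a tactic reuses a mesh axis on the same tensor or if
    the chosen dimension is not divisible by the mesh axis size.
    """
    local = {name: list(shape) for name, shape in tensors.items()}
    used_axes = {name: set() for name in tensors}
    for tactic in tactics:
        size = mesh[tactic.axis]
        if tactic.axis in used_axes[tactic.name]:
            raise ValueError(f"axis {tactic.axis!r} already used on {tactic.name!r}")
        if local[tactic.name][tactic.dim] % size != 0:
            raise ValueError(f"dim {tactic.dim} of {tactic.name!r} not divisible by {size}")
        local[tactic.name][tactic.dim] //= size
        used_axes[tactic.name].add(tactic.axis)
    return {name: tuple(shape) for name, shape in local.items()}


# Strategies live outside the model and compose by concatenation.
MESH = {"batch": 4, "model": 2}
BATCH_PARALLEL = [ShardTactic("x", 0, "batch")]
MODEL_PARALLEL = [ShardTactic("w", 1, "model")]

LOCAL_SHAPES = apply_tactics(MODEL_TENSORS, MESH, BATCH_PARALLEL + MODEL_PARALLEL)
print(LOCAL_SHAPES)
```

Changing the strategy here means editing only the tactic lists; `MODEL_TENSORS` is untouched, which mirrors the separation the abstract describes.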

Paper Structure (81 sections, 5 theorems, 29 equations, 18 figures, 4 tables)

This paper contains 81 sections, 5 theorems, 29 equations, 18 figures, 4 tables.

Key Result

theorem 1

Let $\Gamma \vdash^{\mathtt{core}} \mathtt{C}[e] : \mathtt{tensor}\langle\overline{n}\rangle$, let $\Gamma\vdash\mathcal{M}$, and let $\overline{\sigma}$ be such that $\mathit{axes}(\overline{\sigma}) = a_1\cdots a_k$, where $r_{a_1},\ldots,r_{a_k}$ are t...

Figures (18)

  • Figure 1: Top: batch parallelism, the gradients are AllReduced to update the parameters. Bottom: Z3/FSDP, note the parameters are sharded and only AllGathered before their use, highlighted in thick blue bars in the figure. The gradients are ReducedScattered before updating the parameters.
  • Figure 2: Batch (N) and model (M) parallelism. On the left, the communication along the M axis (e.g., Megatron's [megatron2019] activation reductions). On the right, the communication along the N axis (e.g., gradient reductions). Each device parameter shard is color-coded; all devices along N hold the same shard.
  • Figure 3: PartIR partitioning stack, built using MLIR, supporting layering of new operators on top of existing ones -- hence, we use "+" to signify the introduction of new operators, and "-" to signify that operators have now become illegal.
  • Figure 4: Programs equivalent to , where is a mesh axis of size 4, assuming and .
  • Figure 5: Demonstration of sequences of and collectives on a mesh of four devices , each device is represented as a box in the figure. Top: all devices hold the same 2D array; bottom: data is sliced row-wise along axis "y"; right: data is further sliced column-wise along axis "x". In each case we give the device-local tensor types.
  • ...and 13 more figures
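
The two schemes in the Figure 1 caption can be compared with a small simulation. The sketch below is pure Python with no real devices; the function names, the learning rate, and the toy gradients are assumptions made for illustration and are not taken from the paper. It checks the well-known identity that AllReduce followed by a full update gives the same parameters as ReduceScatter, a shard-local update, and AllGather.

```python
# Simulated collectives over a list of per-device vectors (illustration only).

def all_reduce(per_device):
    """Every device ends up with the elementwise sum over all devices."""
    total = [sum(vals) for vals in zip(*per_device)]
    return [list(total) for _ in per_device]


def reduce_scatter(per_device):
    """Device i ends up with the i-th equal chunk of the elementwise sum."""
    n = len(per_device)
    total = [sum(vals) for vals in zip(*per_device)]
    chunk = len(total) // n
    return [total[i * chunk:(i + 1) * chunk] for i in range(n)]


def all_gather(per_device):
    """Every device ends up with the concatenation of all devices' chunks."""
    full = [v for shard in per_device for v in shard]
    return [list(full) for _ in per_device]


def sgd_step(params, grads, lr=0.1):
    """Plain SGD update on a flat list of parameters."""
    return [p - lr * g for p, g in zip(params, grads)]


NUM_DEVICES = 4
PARAMS = [1.0] * 8
# Each device computes a different local gradient from its slice of the batch.
GRADS = [[float(d + i) for i in range(8)] for d in range(NUM_DEVICES)]

# Batch parallelism: parameters are replicated, gradients are AllReduced.
BATCH_PARALLEL = [sgd_step(PARAMS, g) for g in all_reduce(GRADS)]

# Z3/FSDP style: parameters are sharded, gradients are ReduceScattered,
# each device updates its own shard, and shards are AllGathered before use.
CHUNK = len(PARAMS) // NUM_DEVICES
PARAM_SHARDS = [PARAMS[i * CHUNK:(i + 1) * CHUNK] for i in range(NUM_DEVICES)]
GRAD_SHARDS = reduce_scatter(GRADS)
NEW_SHARDS = [sgd_step(p, g) for p, g in zip(PARAM_SHARDS, GRAD_SHARDS)]
SHARDED = all_gather(NEW_SHARDS)

print(BATCH_PARALLEL == SHARDED)
```

The sharded variant stores only `CHUNK` parameters per device between steps, at the cost of an AllGather before each use, which is the trade-off the caption highlights.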

Theorems & Definitions (7)

  • definition 1: Judgement $\Gamma\vdash\mathcal{M}$, Associated SPMD Typing Context
  • theorem 1: Typing simulation
  • definition 2: Environment relation
  • lemma 1
  • lemma 2: and Definition of Extension of Environments
  • lemma 3
  • theorem 2: Correctness of translation