Composing Distributed Computations Through Task and Kernel Fusion
Rohan Yadav, Shiv Sundram, Wonchan Lee, Michael Garland, Michael Bauer, Alex Aiken, Fredrik Kjolstad
TL;DR
Diffuse tackles the challenge of efficiently composing distributed, task-based computations by introducing a scale-free intermediate representation that enables scalable, alias-aware fusion analyses across library boundaries. It pairs distributed task fusion with an MLIR-based JIT compiler that fuses the kernels inside fused tasks, accelerating unmodified cuNumeric and Legate Sparse workloads to match or exceed hand-tuned, MPI-based baselines. The approach is domain-agnostic: it finds optimizations across function and library boundaries, eliminates distributed temporaries through temporary-store elimination, and keeps analysis costs low through memoization. In practice, Diffuse delivers a 1.86x geometric-mean speedup on up to 128 GPUs and up to 1.23x speedups over hand-optimized code while keeping compilation overhead modest, suggesting broad applicability to scalable, distributed scientific computing.
Abstract
We introduce Diffuse, a system that dynamically performs task and kernel fusion in distributed, task-based runtime systems. The key component of Diffuse is an intermediate representation of distributed computation that enables the necessary analyses for the fusion of distributed tasks to be performed in a scalable manner. We pair task fusion with a JIT compiler to fuse the kernels within fused tasks. We show empirically that Diffuse's intermediate representation is general enough to be a target for two real-world, task-based libraries (cuNumeric and Legate Sparse), letting Diffuse find optimization opportunities across function and library boundaries. Diffuse accelerates unmodified applications developed by composing task-based libraries by 1.86x on average (geo-mean), and by between 0.93x and 10.7x on up to 128 GPUs. Diffuse also finds optimization opportunities missed by the original application developers, enabling high-level Python programs to match or exceed the performance of an explicitly parallel MPI library.
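To make the fusion opportunity concrete, below is a minimal, illustrative sketch (not taken from the paper) of the kind of unmodified cuNumeric program Diffuse targets. The function name and array sizes are invented for illustration; the only assumed API is cuNumeric's documented role as a drop-in, distributed replacement for NumPy. Written naively, each NumPy-style operation launches its own distributed task and materializes a distributed temporary; a fusing runtime like Diffuse can collapse the chain into a single task whose kernels are JIT-fused, and eliminate the temporaries.

```python
# Illustrative sketch only: shifted_scale is a hypothetical user function,
# not part of any library. cuNumeric exposes a NumPy-compatible API that
# runs each operation as a distributed task.
import cunumeric as np

def shifted_scale(alpha, x, y):
    # Without fusion, each line below is a separate distributed task
    # that materializes a distributed temporary array:
    t0 = alpha * x       # task 1, temporary t0
    t1 = t0 + y          # task 2, temporary t1
    return np.sin(t1)    # task 3, final result
    # A system like Diffuse can observe that t0 and t1 are dead after
    # this chain, fuse the three tasks into one, and remove both
    # distributed temporaries (temporary-store elimination).

x = np.random.rand(1_000_000)
y = np.random.rand(1_000_000)
z = shifted_scale(2.0, x, y)
```

Because the analysis operates on an intermediate representation of the distributed computation rather than on library-specific operators, the same fusion would apply when the chain mixes operations from different libraries, e.g. a Legate Sparse matrix-vector product feeding cuNumeric element-wise operations.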
