Iris: First-Class Multi-GPU Programming Experience in Triton
Muhammad Awad, Muhammad Osama, Brandon Potter
TL;DR
Iris tackles the challenge of making multi-GPU programming both high-performance and developer-friendly by introducing a native Python+Triton library built on tile-based symmetric memory. It enables fine-grained overlap of computation and communication through tensor-level, tile-centric primitives. It also presents a taxonomy of unfused and fused overlap patterns, demonstrated on GEMM+All-Scatter workloads. The results show near-optimal bandwidth in microbenchmarks and up to 1.79x speedup over established baselines such as PyTorch and RCCL. This shows that simpler, compiler-visible programming can deliver substantial performance gains. The work argues that the barrier to fine-grained overlap is abstraction, not hardware limits. It proposes a path toward broader adoption through open-source, Triton-native primitives that bring computation and communication under a single compiler view.
Abstract
Multi-GPU programming traditionally forces developers to navigate complex trade-offs between performance and programmability. High-performance implementations typically rely on low-level HIP/CUDA communication libraries that demand substantial engineering effort for even basic overlap patterns, while simpler abstractions often sacrifice performance. We present Iris, a multi-GPU communication library implemented entirely in Python and Triton that eliminates this trade-off. Iris provides tile-based symmetric memory abstractions that align naturally with Triton's programming model. These abstractions let developers write single-source kernels that interleave computation and communication. We demonstrate a taxonomy of compute-communication overlap patterns, from bulk-synchronous execution to fine-grained workgroup specialization. Each pattern can be implemented in Iris with minimal code changes, often just a few additional lines within the same Triton kernel. Our evaluation shows that Iris achieves near-optimal bandwidth utilization in microbenchmarks and delivers up to 1.79x speedup over PyTorch and RCCL for GEMM+All-Scatter workloads. These results demonstrate that high-level implementations can match or exceed heavily optimized libraries while dramatically simplifying multi-GPU programming.
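To make "tile-based symmetric memory" concrete, here is a minimal plain-Python sketch of the general idea under illustrative assumptions. Every rank allocates a heap of the same size, and every rank knows the base address of each peer's heap. Under those assumptions, a local pointer can be mapped to the corresponding remote location by preserving its offset from the heap base. The names `SymmetricHeap`, `translate`, and `tile_offsets` are hypothetical and are not Iris's API. A real implementation would perform this arithmetic on device pointers inside a Triton kernel rather than on Python integers.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SymmetricHeap:
    """Hypothetical model of a symmetric heap: every rank allocates `size`
    bytes, and heap_bases[r] is the base address of rank r's heap."""
    heap_bases: List[int]
    size: int

    def translate(self, ptr: int, from_rank: int, to_rank: int) -> int:
        """Map an address in from_rank's heap to the same offset in to_rank's heap."""
        offset = ptr - self.heap_bases[from_rank]
        if not 0 <= offset < self.size:
            raise ValueError("pointer lies outside the symmetric heap")
        return self.heap_bases[to_rank] + offset


def tile_offsets(tile_m: int, tile_n: int, block_m: int, block_n: int, ld: int) -> List[int]:
    """Element offsets of the (tile_m, tile_n) tile of size block_m x block_n
    in a row-major matrix whose leading dimension (row stride) is ld."""
    return [
        (tile_m * block_m + i) * ld + (tile_n * block_n + j)
        for i in range(block_m)
        for j in range(block_n)
    ]


# Usage: two ranks with same-sized heaps at different base addresses.
heap = SymmetricHeap(heap_bases=[0x1000, 0x9000], size=0x8000)
remote = heap.translate(0x1010, from_rank=0, to_rank=1)  # same offset on rank 1

# The bottom-right 2x2 tile of a 4x4 row-major matrix.
offs = tile_offsets(1, 1, block_m=2, block_n=2, ld=4)
print(hex(remote), offs)
```

The design choice illustrated here is that, because every heap has the same layout, a whole tile's remote addresses follow from one base-address substitution plus the same per-tile offsets a local load or store would use. In this model, communication looks like ordinary tile indexing.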
