ML-Triton, A Multi-Level Compilation and Language Extension to Triton GPU Programming
Dewei Wang, Wei Zhu, Liyang Ling, Ettore Tiotto, Quintin Wang, Whitney Tsang, Julian Opperman, Jacky Deng
TL;DR
The paper addresses the mismatch between the hierarchical structure of GPU hardware and Triton's single-level lowering by proposing ML-Triton, a multi-level compilation flow that starts at the workgroup level and progressively lowers to the warp and intrinsic levels. It also extends the Triton language with user-set compiler hints and warp-level programming. These extensions expose explicit data partitioning between warps and let users get good out-of-the-box performance on modern GPUs without waiting for compiler updates. Key contributions are the hierarchical lowering strategy, compiler hints, and warp-level programming support. The reported results exceed 95% of the performance of expert-written kernels on Intel GPUs (geometric mean). The approach keeps Triton portable and accessible to researchers while maintaining high efficiency for dense operations such as GEMM and MHA in AI workloads.
Abstract
In the era of LLMs, dense operations such as GEMM and MHA are critical components. These operations are well suited to parallel execution using a tile-based approach. While traditional GPU programming often relies on low-level interfaces such as CUDA or SYCL, Triton has emerged as a DSL that offers a more user-friendly and portable alternative by programming at a higher level. Current Triton starts at the workgroup (aka threadblock) level and lowers directly to the per-thread level. It then attempts to coalesce and amend the result through a series of passes that promote information back from the low-level representation. We believe this is premature lowering, based on the following observations. 1. GPUs have a hierarchical structure, both physically and logically. Modern GPUs often feature SIMD units capable of operating directly on tiles on a per-warp or per-warpgroup basis, such as blocked loads and blocked MMA. 2. Gradual multi-level lowering can keep the compiler decoupled and clean by separating concerns between and within logical layers. 3. Kernel developers often need fine-grained control to get good performance on the latest hardware; for example, FlashAttention2 advocates explicit data partitioning between warps to boost performance. In this context, we propose ML-Triton, which features a multi-level compilation flow and programming interface. Our approach begins at the workgroup level and progressively lowers to the warp and intrinsic levels, implementing a multi-level lowering aligned with the hierarchical nature of GPUs. Additionally, we extend the Triton language to support user-set compiler hints and warp-level programming, enabling researchers to get good out-of-the-box performance without waiting for compiler updates. Experimental results demonstrate that our approach achieves performance above 95% of expert-written kernels on Intel GPUs, as measured by the geometric mean.
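The workgroup, warp, and intrinsic levels described in the abstract can be illustrated with a small pure-Python sketch of a tiled GEMM. This is a hypothetical illustration, not ML-Triton's API or implementation. The names `warp_tiles`, `block_mma`, and `hierarchical_gemm` are invented for this example. The K dimension is not tiled. `block_mma` merely stands in for a hardware blocked MMA issued by a single warp.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Tile = Tuple[int, int, int, int]  # (row_offset, col_offset, rows, cols)


def warp_tiles(wg_m: int, wg_n: int, warps_m: int, warps_n: int) -> List[Tile]:
    """Hypothetical helper: split a wg_m x wg_n workgroup tile into a
    warps_m x warps_n grid of warp tiles (explicit warp-level partitioning)."""
    if wg_m % warps_m or wg_n % warps_n:
        raise ValueError("workgroup tile must divide evenly among warps")
    tm, tn = wg_m // warps_m, wg_n // warps_n
    # Warps are numbered row-major over the warp grid.
    return [(wm * tm, wn * tn, tm, tn)
            for wm in range(warps_m) for wn in range(warps_n)]


def block_mma(a: Matrix, b: Matrix, c: Matrix, tile: Tile) -> None:
    """Hypothetical stand-in for one warp-level blocked MMA:
    accumulate (A @ B) into one sub-tile of C, over the full K dimension."""
    r0, c0, rows, cols = tile
    k = len(b)
    for i in range(r0, r0 + rows):
        for j in range(c0, c0 + cols):
            c[i][j] += sum(a[i][kk] * b[kk][j] for kk in range(k))


def hierarchical_gemm(a: Matrix, b: Matrix, wg_m: int, wg_n: int,
                      warps_m: int, warps_n: int) -> Matrix:
    """Compute A @ B by descending workgroup tile -> warp tile -> block op."""
    m, n = len(a), len(b[0])
    if m % wg_m or n % wg_n:
        raise ValueError("matrix must divide evenly into workgroup tiles")
    c = [[0.0] * n for _ in range(m)]
    # Level 1: one workgroup per output tile (a loop here, a launch grid on a GPU).
    for wr in range(0, m, wg_m):
        for wc in range(0, n, wg_n):
            # Level 2: explicit partition of the workgroup tile among warps.
            for dr, dc, rows, cols in warp_tiles(wg_m, wg_n, warps_m, warps_n):
                # Level 3: each warp issues a single block-level operation.
                block_mma(a, b, c, (wr + dr, wc + dc, rows, cols))
    return c
```

For example, `hierarchical_gemm(a, b, wg_m=4, wg_n=4, warps_m=2, warps_n=2)` on a 4x2 by 2x4 input uses a single workgroup tile and gives each of the four warps a 2x2 sub-tile. The sketch keeps the warp partition as its own explicit level, so the innermost step is one block operation per warp rather than per-thread scalar work. This mirrors, in a simplified way, the argument against lowering straight from workgroup to thread level.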
