SPLAT: A framework for optimised GPU code-generation for SParse reguLar ATtention
Ahan Gupta, Yueming Yuan, Devansh Jain, Yuhao Ge, David Aponte, Yanqi Zhou, Charith Mendis
TL;DR
SPLAT targets the quadratic cost of multi-head self-attention (MHSA). It introduces affine-compressed-sparse-row (ACSR), a regularity-informed sparse format, together with a just-in-time GPU code-generation pipeline. The framework develops specialized kernels (R-SDDMM and R-SpMM) and a tiling strategy (poset tiling) that exploit per-row affine indices to reduce metadata and memory traffic while still covering a broad range of sparse-MHSA patterns. Empirical evaluation shows that SPLAT delivers substantial speedups over vendor libraries (up to $2.46\times$ for R-SDDMM and $2.81\times$ for R-SpMM on representative patterns) and over hand-written kernels (geomean speedups of $2.05\times$ over Triton and $4.05\times$ over TVM for end-to-end sparse-MHSA on A100 GPUs), with additional gains from optimized data-layout choices and end-to-end code generation. These results show that pairing a regularity-aware sparse format with code generation can deliver both generality and high performance for sparse-MHSA, enabling efficient inference for long-context transformers.
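To make the metadata argument concrete, the sketch below uses a hypothetical, simplified encoding in which each row stores one affine index triple `(start, stride, count)`, so row i's non-zero columns are `start + stride * k` for `k` in `range(count)`. This layout is our assumption for illustration, not necessarily ACSR's actual definition, and `AffineRowMatrix`, `sliding_window`, and `csr_metadata_ints` are hypothetical names. The sketch compares how many integers of index metadata this encoding and plain CSR (`row_ptr` plus `col_idx`) need for a causal sliding-window pattern.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class AffineRowMatrix:
    """Hypothetical ACSR-like index layout (illustration only).

    Row i stores (start, stride, count); its non-zero columns are
    start + stride * k for k in range(count).
    """
    n_cols: int
    rows: List[Tuple[int, int, int]]

    def row_columns(self, i: int) -> List[int]:
        start, stride, count = self.rows[i]
        return [start + stride * k for k in range(count)]

    def metadata_ints(self) -> int:
        # Three integers per row, independent of how many non-zeros it has.
        return 3 * len(self.rows)


def sliding_window(n: int, w: int) -> AffineRowMatrix:
    """Causal sliding window: row i attends to columns max(0, i-w+1) .. i."""
    rows = []
    for i in range(n):
        start = max(0, i - w + 1)
        rows.append((start, 1, i - start + 1))
    return AffineRowMatrix(n_cols=n, rows=rows)


def csr_metadata_ints(m: AffineRowMatrix) -> int:
    """Integers plain CSR would store: len(row_ptr) + len(col_idx)."""
    nnz = sum(count for _, _, count in m.rows)
    return (len(m.rows) + 1) + nnz


m = sliding_window(8, 4)
print(m.row_columns(0), m.row_columns(7))  # [0] [4, 5, 6, 7]

big = sliding_window(1024, 64)
print(big.metadata_ints(), csr_metadata_ints(big))  # 3072 64545
```

Under this assumption, index metadata grows with the number of rows rather than the number of non-zeros. A stride other than 1 would also describe dilated rows; patterns whose rows are not a single affine run would need a more general encoding.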
Abstract
Multi-head self-attention (MHSA) mechanisms achieve state-of-the-art (SOTA) performance across natural language processing and vision tasks. However, their quadratic dependence on sequence length has bottlenecked inference speed. To circumvent this bottleneck, researchers have proposed various sparse-MHSA models, which compute only a subset of the full attention. Despite their promise, current sparse libraries and compilers do not support high-performance implementations for diverse sparse-MHSA patterns because of the sparse formats they operate on. These formats, typically designed for high-performance and scientific computing applications, are tailored either to extreme random sparsity (<1% non-zero values) or to specific sparsity patterns. The sparsity patterns in sparse-MHSA, however, are moderately sparse (10-50% non-zero values) and varied, so existing sparse formats force a trade-off between generality and performance. We bridge this gap, achieving both generality and performance, by proposing a novel sparse format, affine-compressed-sparse-row (ACSR), and a supporting code-generation scheme, SPLAT, that generates high-performance implementations for diverse sparse-MHSA patterns on GPUs. Core to our format and code-generation algorithm is the observation that common sparse-MHSA patterns have uniquely regular geometric properties. These properties, which can be analyzed just-in-time, expose novel optimizations and tiling strategies that SPLAT exploits to generate high-performance implementations for diverse patterns. To demonstrate SPLAT's efficacy, we use it to generate code for various sparse-MHSA models, achieving geomean speedups of 2.05x and 4.05x over hand-written kernels written in Triton and TVM, respectively, on A100 GPUs. Moreover, SPLAT's interfaces are intuitive and integrate easily with existing MHSA implementations in JAX.
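Sparse-MHSA is commonly decomposed into an SDDMM (attention scores computed only at the pattern's non-zero positions), a row-wise softmax over those scores, and an SpMM with the value matrix; we assume this is the computation that R-SDDMM and R-SpMM accelerate. The NumPy sketch below is a reference for that computation, not SPLAT's generated code. The hypothetical `sparse_attention` gathers only the keys and values each query row attends to and is checked against a dense attention under the same mask; the per-row Python loop is for clarity, not speed.

```python
import numpy as np


def sliding_window_mask(n, w):
    """Boolean mask: query i attends to keys max(0, i-w+1) .. i."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (j > i - w)


def dense_masked_attention(q, k, v, mask):
    """Full attention with masked-out scores set to -inf (reference)."""
    scores = np.where(mask, q @ k.T / np.sqrt(q.shape[-1]), -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v


def sparse_attention(q, k, v, row_cols):
    """row_cols[i] holds the key indices query row i attends to."""
    d = q.shape[-1]
    out = np.zeros((q.shape[0], v.shape[1]))
    for i, cols in enumerate(row_cols):
        # SDDMM step: dot products only at the sampled (i, j) positions.
        s = k[cols] @ q[i] / np.sqrt(d)
        # Numerically stable softmax over this row's non-zeros.
        s = np.exp(s - s.max())
        s /= s.sum()
        # SpMM step: weighted sum over only the selected value rows.
        out[i] = s @ v[cols]
    return out


rng = np.random.default_rng(0)
n, d, w = 16, 8, 4
q, k, v = rng.standard_normal((3, n, d))
mask = sliding_window_mask(n, w)
row_cols = [np.flatnonzero(mask[i]) for i in range(n)]

print(mask.mean())  # 0.2265625
print(np.allclose(sparse_attention(q, k, v, row_cols),
                  dense_masked_attention(q, k, v, mask)))  # True
```

For this small window the density falls inside the 10-50% range the abstract describes, which is too dense for formats built for <1% non-zeros to pay off but sparse enough that skipping masked positions matters.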
