SPLAT: A framework for optimised GPU code-generation for SParse reguLar ATtention

Ahan Gupta, Yueming Yuan, Devansh Jain, Yuhao Ge, David Aponte, Yanqi Zhou, Charith Mendis

TL;DR

SPLAT targets the quadratic cost of MHSA by introducing affine-compressed-sparse-row (ACSR), a sparse format built around the regularity of sparse-MHSA patterns, together with a just-in-time GPU code-generation pipeline. The framework generates specialized kernels (R-SDDMM and R-SpMM) and uses a tiling strategy (poset tiling) that exploits per-row affine indices to reduce metadata and memory traffic while still covering a broad range of sparse-MHSA patterns. In the reported evaluation, SPLAT outperforms vendor libraries (up to $2.46\times$ in R-SDDMM and $2.81\times$ in R-SpMM on representative patterns) and hand-written kernels (geomean speedups of $2.05\times$ over Triton and $4.05\times$ over TVM for end-to-end sparse-MHSA on A100 GPUs), with additional gains from optimized data-layout choices and end-to-end code generation. The results indicate that pairing a regularity-aware format with code generation can deliver both generality and high performance for sparse-MHSA, which supports efficient inference for long-context transformers.
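As a rough illustration of why per-row affine indices can shrink metadata, the following Python sketch compresses each mask row whose non-zero columns are evenly spaced into three integers, whereas a CSR-style layout stores one column index per non-zero. This is not SPLAT's actual ACSR layout: the mask definitions, the `(start, stride, count)` record, and all function names are assumptions made for this example.

```python
from typing import List, Optional, Tuple

# Hypothetical per-row record (start, stride, count); NOT the paper's ACSR layout.
AffineRow = Tuple[int, int, int]


def windowed_mask(n: int, w: int) -> List[List[int]]:
    """Assumed windowed pattern: row i keeps columns within distance w of i."""
    return [[1 if abs(i - j) <= w else 0 for j in range(n)] for i in range(n)]


def strided_mask(n: int, s: int) -> List[List[int]]:
    """Assumed strided pattern: row i keeps columns j with (i - j) divisible by s."""
    return [[1 if (i - j) % s == 0 else 0 for j in range(n)] for i in range(n)]


def compress_row(row: List[int]) -> Optional[AffineRow]:
    """Return (start, stride, count) if the non-zero columns are evenly spaced, else None."""
    cols = [j for j, v in enumerate(row) if v]
    if not cols:
        return (0, 1, 0)
    if len(cols) == 1:
        return (cols[0], 1, 1)
    stride = cols[1] - cols[0]
    if any(b - a != stride for a, b in zip(cols, cols[1:])):
        return None  # row is not affine; a fallback format would be needed
    return (cols[0], stride, len(cols))


def expand_row(rec: AffineRow) -> List[int]:
    """Recompute the column indices from the affine record."""
    start, stride, count = rec
    return [start + stride * k for k in range(count)]


def metadata_sizes(mask: List[List[int]]) -> Tuple[int, Optional[int]]:
    """Integers of index metadata: one per non-zero (CSR columns) vs three per row."""
    nnz = sum(sum(row) for row in mask)
    recs = [compress_row(row) for row in mask]
    affine = 3 * len(mask) if all(r is not None for r in recs) else None
    return nnz, affine
```

For a 16x16 windowed mask with half-width 3, `metadata_sizes` reports 100 column indices for the CSR-style layout against 48 integers for the affine records, and the gap widens as rows get denser, because the affine cost per row does not depend on the number of non-zeros.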

Abstract

Multi-head-self-attention (MHSA) mechanisms achieve state-of-the-art (SOTA) performance across natural language processing and vision tasks. However, their quadratic dependence on sequence length has bottlenecked inference speeds. To circumvent this bottleneck, researchers have proposed various sparse-MHSA models, where a subset of full attention is computed. Despite their promise, current sparse libraries and compilers do not support high-performance implementations for diverse sparse-MHSA patterns due to the underlying sparse formats they operate on. These formats, which are typically designed for high-performance and scientific computing applications, are tailored either to extreme random sparsity (<1% non-zero values) or to specific sparsity patterns. However, the sparsity patterns in sparse-MHSA are moderately sparse (10-50% non-zero values) and varied, so existing sparse formats trade off generality for performance. We bridge this gap, achieving both generality and performance, by proposing a novel sparse format, affine-compressed-sparse-row (ACSR), and a supporting code-generation scheme, SPLAT, that generates high-performance implementations for diverse sparse-MHSA patterns on GPUs. Core to our proposed format and code-generation algorithm is the observation that common sparse-MHSA patterns have uniquely regular geometric properties. These properties, which can be analyzed just-in-time, expose novel optimizations and tiling strategies that SPLAT exploits to generate high-performance implementations for diverse patterns. To demonstrate SPLAT's efficacy, we use it to generate code for various sparse-MHSA models, achieving geomean speedups of 2.05x and 4.05x over hand-written kernels written in Triton and TVM respectively on A100 GPUs. Moreover, its interfaces are intuitive and easy to use with existing implementations of MHSA in JAX.
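Computing only a subset of full attention is commonly expressed as three steps: a sampled dense-dense product (SDDMM) that evaluates query-key scores only where the mask is set, a row-wise softmax over the kept scores, and a sparse-dense product (SpMM) with the values. The pure-Python sketch below shows that pipeline for a single head under a 0/1 mask. It is an illustration only, with hypothetical function names, no batching, and no GPU specifics; it is not the code SPLAT generates.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]
SparseRows = List[Dict[int, float]]  # one {column: value} dict per row


def sddmm(q: Matrix, k: Matrix, mask: List[List[int]]) -> SparseRows:
    """Scaled scores q[i] . k[j], computed only where mask[i][j] is non-zero."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    return [
        {j: scale * sum(qi[t] * k[j][t] for t in range(d))
         for j, m in enumerate(mask_row) if m}
        for qi, mask_row in zip(q, mask)
    ]


def row_softmax(scores: SparseRows) -> SparseRows:
    """Numerically stable softmax over the kept entries of each row."""
    out: SparseRows = []
    for row in scores:
        if not row:
            out.append({})
            continue
        mx = max(row.values())
        exps = {j: math.exp(s - mx) for j, s in row.items()}
        z = sum(exps.values())
        out.append({j: e / z for j, e in exps.items()})
    return out


def spmm(probs: SparseRows, v: Matrix) -> Matrix:
    """Sparse (probs) times dense (v): each output row mixes only the kept value rows."""
    dv = len(v[0])
    return [[sum(p * v[j][c] for j, p in row.items()) for c in range(dv)]
            for row in probs]


def sparse_attention(q: Matrix, k: Matrix, v: Matrix, mask: List[List[int]]) -> Matrix:
    return spmm(row_softmax(sddmm(q, k, mask)), v)
```

With an identity mask every row attends only to itself, so the output equals `v`; with an all-ones mask the function reduces to ordinary dense attention.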

Paper Structure (40 sections, 9 theorems, 13 equations, 17 figures, 1 table, 1 algorithm)

This paper contains 40 sections, 9 theorems, 13 equations, 17 figures, 1 table, 1 algorithm.

Key Result

Theorem 1

Given a point-set $P$ and an arrangement of thread-blocks $TB_{poset} = \{TB_1, TB_2, \ldots, TB_{\lambda_{poset}}\}$, each of size $m \times n$, generated by Algorithm 1 (greedy tiling) to cover $P$, let $TB_{opt} = \{TB_1, TB_2, \ldots, TB_{\lambda_{opt}}\}$ be the arrangement of lowest possible cost to cover $P$. Then ... where $l$ is the maximum number of points in a row of the mask. Moreover, for the strided pattern, ...

Figures (17)

  • Figure 1: Run-time results for a sparse primitive used in sparse-MHSA (R-SpMM) comparing cuSPARSE, cuBLAS and SPLAT. We vary the density of the sparse input across: [0.4, 0.8, 1.6, 3, 6, 12, 24, 44, 75, 100]. The sparse input takes the shape of the blocked pattern (Figure 2, right).
  • Figure 2: Examples of 3 commonly occurring sparse-MHSA patterns in the literature: strided (left), windowed (middle; Longformer), and blocked (right; Sparse Transformer, Reformer). Full attention computes all points
  • Figure 3: A comparison between SpMM implementations that use the CSR format (b), and a specialized format (c). (d) and (e) are naive SpMM implementations of $C=AB$, when $A$ is represented as a CSR and the specialized format of (c), respectively.
  • Figure 4: Profile of the R-SpMM sparse primitive implemented in SPLAT, cuBLAS, cuSPARSE and Triton. Matrices are 1024x1024, with sparse matrices in the windowed pattern (see Figure 2, middle) at 24% density. FFMA is an FP32 fused multiply-add instruction and L2 read is the amount of data traffic (in GB) from L2 to L1 cache. Lower is better.
  • Figure 5: An overview of SPLAT's inner mechanics and how its just-in-time strategy produces compiled sparse-MHSA kernels for inference.
  • ...and 12 more figures
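
The contrast described in the Figure 3 caption, between a naive SpMM over CSR and one over a specialized format, can be sketched in plain Python. The CSR kernel below is the standard one. The "specialized" variant is an assumption made for illustration (a hypothetical `(start, stride, count)` record per row, with names chosen here), not the format drawn in the paper's figure.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def dense_to_csr(a: Matrix) -> Tuple[List[int], List[int], List[float]]:
    """Build CSR arrays (row_ptr, col_idx, vals) from a dense matrix."""
    row_ptr, col_idx, vals = [0], [], []
    for row in a:
        for j, x in enumerate(row):
            if x != 0:
                col_idx.append(j)
                vals.append(x)
        row_ptr.append(len(col_idx))
    return row_ptr, col_idx, vals


def csr_spmm(row_ptr: List[int], col_idx: List[int], vals: List[float], b: Matrix) -> Matrix:
    """Naive C = A @ B with A in CSR: every non-zero needs an index lookup."""
    n_cols = len(b[0])
    c = [[0.0] * n_cols for _ in range(len(row_ptr) - 1)]
    for i in range(len(row_ptr) - 1):
        for p in range(row_ptr[i], row_ptr[i + 1]):
            k = col_idx[p]  # indirect access: column index loaded from memory
            a_ik = vals[p]
            for j in range(n_cols):
                c[i][j] += a_ik * b[k][j]
    return c


def affine_spmm(rows: List[Tuple[int, int, int]], vals: List[List[float]], b: Matrix) -> Matrix:
    """Naive C = A @ B where row i's columns are start + stride * t (hypothetical format)."""
    n_cols = len(b[0])
    c = [[0.0] * n_cols for _ in rows]
    for i, (start, stride, count) in enumerate(rows):
        for t in range(count):
            k = start + stride * t  # column index computed, not loaded
            a_ik = vals[i][t]
            for j in range(n_cols):
                c[i][j] += a_ik * b[k][j]
    return c
```

The only difference between the two loops is how `k` is obtained: the CSR version reads it from `col_idx`, while the affine version derives it arithmetically from three per-row integers, which is the kind of saving in metadata traffic that a regularity-aware format aims for.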

Theorems & Definitions (16)

  • Definition 1
  • Definition 2
  • Definition 3
  • Definition 4
  • Definition 5
  • Theorem 1
  • Theorem 3
  • Theorem 4
  • Theorem 5
  • Theorem 6
  • ...and 6 more