Accelerating a Triton Fused Kernel for W4A16 Quantized Inference with SplitK work decomposition

Adnan Hoque, Less Wright, Chih-Chieh Yang, Mudhakar Srivatsa, Raghu Ganti

TL;DR

Problem: memory-bound quantized inference with skinny matrices in foundation-model workloads. Approach: a fused Triton kernel that performs dequantization and GEMM using a SplitK work decomposition with atomic reductions. Contributions: average speedups of about 65% on A100 and 124% on H100 across a range of matrix shapes, including llama-style ones, with a peak of 295% on H100; and an analysis attributing the gains to finer-grained work distribution and higher memory throughput. Impact: faster W4A16 inference for large language models, and groundwork for future StreamK refinements.
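The claim that skinny W4A16 GEMMs are memory-bound can be checked with a back-of-the-envelope arithmetic-intensity estimate (FLOPs per byte moved from memory). The sketch below is illustrative only. The operand sizes (fp16 activations and outputs, packed 4-bit weights, one fp16 scale and one fp16 zero point per group of 128 weight rows), the function names, and the approximate A100 40GB spec-sheet peaks are all assumptions, not measurements or code from the paper.

```python
# Back-of-the-envelope roofline check for C[m, n] = A[m, k] @ W[k, n] with 4-bit W.
# All operand sizes and hardware peaks below are illustrative assumptions.

A100_FP16_TFLOPS = 312.0   # approx. dense FP16 tensor-core peak (spec sheet)
A100_HBM_TBPS = 1.555      # approx. HBM bandwidth of A100 40GB (spec sheet)
RIDGE = A100_FP16_TFLOPS / A100_HBM_TBPS  # FLOP/byte above which compute is the limit


def w4a16_gemm_traffic(m, n, k, group_size=128):
    """Return (flops, bytes) for one W4A16 GEMM, assuming each operand is read once."""
    flops = 2 * m * n * k                        # one multiply + one add per MAC
    weight_bytes = k * n // 2                    # two 4-bit weights per byte
    meta_bytes = (k // group_size) * n * 2 * 2   # assumed fp16 scale + fp16 zero per group
    act_bytes = m * k * 2                        # fp16 activations
    out_bytes = m * n * 2                        # fp16 output
    return flops, weight_bytes + meta_bytes + act_bytes + out_bytes


def arithmetic_intensity(m, n, k, group_size=128):
    """FLOPs performed per byte of memory traffic."""
    flops, nbytes = w4a16_gemm_traffic(m, n, k, group_size)
    return flops / nbytes


# Skinny activations against a square 4096 x 4096 weight matrix.
for m in (1, 16):
    ai = arithmetic_intensity(m, 4096, 4096)
    print(f"m={m:3d}  intensity={ai:6.1f} FLOP/B  memory-bound={ai < RIDGE}")
```

Under these assumptions the intensity for small m sits well below the ridge point, so streaming the packed weights from HBM, rather than tensor-core throughput, bounds the runtime.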

Abstract

We propose an efficient fused matrix multiplication kernel for W4A16 quantized inference, in which dequantization and GEMM are performed in a single kernel using a SplitK work decomposition. Our implementation improves performance on the skinny matrix-matrix multiplications found in foundation model inference workloads. In particular, this paper focuses on multiplying a skinny activation matrix by a square weight matrix. Our results show an average speed improvement of 65% on A100 and 124% on H100 (with a peak of 295%) across a range of matrix dimensions, including those found in a llama-style model, where the activation matrix is m × k, the weight matrix is k × n, and m < n = k.
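To make "dequantization and GEMM in a fused kernel using a SplitK work decomposition" concrete, here is a minimal pure-Python reference sketch. It is not the authors' Triton kernel. The packing layout (8 consecutive K-rows of 4-bit weights per 32-bit word, low nibble first), per-group scales and zero points, and the names pack_int4, dequant, splitk_w4a16_matmul, and reference_matmul are all hypothetical choices for illustration. Each K-slice dequantizes its weights on the fly and adds a partial result into C; on a GPU those slices run concurrently, so the accumulation would be an atomic add.

```python
import random


def pack_int4(q):
    """Pack a k x n matrix of 4-bit ints (0..15) into (k // 8) x n 32-bit words.

    Assumed layout: K-row kk lives in bits 4*(kk % 8) .. 4*(kk % 8)+3 of word kk // 8.
    """
    k, n = len(q), len(q[0])
    packed = [[0] * n for _ in range(k // 8)]
    for kk in range(k):
        for j in range(n):
            packed[kk // 8][j] |= (q[kk][j] & 0xF) << (4 * (kk % 8))
    return packed


def dequant(packed, scales, zeros, kk, j, group_size):
    """Unpack one 4-bit weight and apply its group's scale and zero point."""
    q = (packed[kk // 8][j] >> (4 * (kk % 8))) & 0xF
    g = kk // group_size
    return (q - zeros[g][j]) * scales[g][j]


def splitk_w4a16_matmul(a, packed, scales, zeros, n, split_k, group_size):
    """C = A @ dequant(W), with the K dimension split into split_k slices.

    Each pid_k slice stands in for an extra grid dimension: it dequantizes only
    its part of W (no full-precision copy of W is ever materialized) and adds
    its partial sums into C. A concurrent GPU version would use an atomic add
    here (e.g. tl.atomic_add in Triton) because slices write the same C tile.
    """
    m, k = len(a), len(a[0])
    assert k % split_k == 0, "sketch assumes K divisible by split_k"
    chunk = k // split_k
    c = [[0.0] * n for _ in range(m)]
    for pid_k in range(split_k):
        k_lo, k_hi = pid_k * chunk, (pid_k + 1) * chunk
        for i in range(m):
            for j in range(n):
                partial = 0.0
                for kk in range(k_lo, k_hi):
                    partial += a[i][kk] * dequant(packed, scales, zeros, kk, j, group_size)
                c[i][j] += partial  # stand-in for the atomic reduction
    return c


def reference_matmul(a, packed, scales, zeros, n, group_size):
    """Unfused baseline: dequantize all of W first, then do a plain matmul."""
    m, k = len(a), len(a[0])
    w = [[dequant(packed, scales, zeros, kk, j, group_size) for j in range(n)] for kk in range(k)]
    return [[sum(a[i][kk] * w[kk][j] for kk in range(k)) for j in range(n)] for i in range(m)]


# Tiny skinny example: m is small relative to k and n.
random.seed(0)
m, n, k, group_size = 2, 4, 32, 16
q = [[random.randrange(16) for _ in range(n)] for _ in range(k)]
packed = pack_int4(q)
scales = [[random.uniform(0.01, 0.1) for _ in range(n)] for _ in range(k // group_size)]
zeros = [[8] * n for _ in range(k // group_size)]
a = [[random.uniform(-1.0, 1.0) for _ in range(k)] for _ in range(m)]
c = splitk_w4a16_matmul(a, packed, scales, zeros, n, split_k=4, group_size=group_size)
ref = reference_matmul(a, packed, scales, zeros, n, group_size)
print(max(abs(c[i][j] - ref[i][j]) for i in range(m) for j in range(n)))
```

Because each slice only changes the order of summation, the SplitK result matches the unfused baseline up to floating-point rounding for any split_k that divides K.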

Paper Structure

This paper contains 12 sections, 12 figures, 9 tables, and 1 algorithm.

Figures (12)

  • Figure 1: SplitK Thread Block Level
  • Figure 2: Data Parallel Thread Block Level
  • Figure 3: SplitK vs Data Parallel TFLOPS A100 40GB
  • Figure 4: SplitK vs Data Parallel TFLOPS A100 80GB
  • Figure 5: SplitK vs Data Parallel TFLOPS H100
  • ...and 7 more figures
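Figures 1 and 2 contrast SplitK and data-parallel decompositions at the thread-block level. One rough way to see why this matters for skinny shapes is to count how many thread blocks each decomposition launches. The sketch below is a simplified model, not the paper's analysis: the tile sizes (BLOCK_M=16, BLOCK_N=64), split_k=4, the helper names grid_blocks and sms_with_work, and the one-block-per-SM scheduling view are hypothetical; only the A100's count of 108 SMs is a hardware fact.

```python
import math

A100_SM_COUNT = 108  # streaming multiprocessors on an A100


def grid_blocks(m, n, block_m, block_n, split_k=1):
    """Thread blocks launched to cover an m x n output.

    Data parallel (split_k=1): one block per BLOCK_M x BLOCK_N output tile,
    each looping over all of K. SplitK: split_k blocks per tile, each covering
    K / split_k and atomically adding its partial tile into C.
    """
    tiles = math.ceil(m / block_m) * math.ceil(n / block_n)
    return tiles * split_k


def sms_with_work(blocks, sm_count=A100_SM_COUNT):
    """Fraction of SMs that receive at least one block (simplified scheduling model)."""
    return min(blocks, sm_count) / sm_count


# Hypothetical skinny shape and tile sizes (not taken from the paper).
m, n, block_m, block_n = 16, 4096, 16, 64
for split_k in (1, 4):
    b = grid_blocks(m, n, block_m, block_n, split_k)
    print(f"split_k={split_k}: {b} blocks, {sms_with_work(b):.0%} of SMs busy")
```

In this toy model the data-parallel launch produces only 64 blocks, fewer than the SM count, while SplitK produces 256 smaller blocks that occupy every SM, at the cost of an extra reduction into C. A fuller model would also account for multiple resident blocks per SM and wave quantization.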