Masked Matrix Multiplication for Emergent Sparsity
Brian Wheatman, Meghana Madhyastha, Randal Burns
TL;DR
MMM addresses the inefficiency of dense GEMMs on transformer-like workloads by exploiting emergent, dual-sided sparsity through runtime masks $M(X)$ and $N(Y)$ and a pattern-driven, code-generated kernel table. The approach uses preprocessing to encode sparsity patterns in $B$ and dynamic code selection to skip zero blocks in $A$, maintaining vectorization with AVX2/AVX-512 and single-pass parallelism. Empirical results show up to 2x speedups and up to 4x fewer instructions over MKL across a wide sparsity range (60–95% zeros), with additional gains on mid-sized matrices and multi-core servers; performance depends on matrix size, sparsity distribution, and architecture. The work demonstrates a practical path to reduce cost, power, and time for sparsity-enabled AI workloads on CPUs and outlines clear directions for GPU extension and broader optimization.
Abstract
Artificial intelligence workloads, especially transformer models, exhibit emergent sparsity, in which computations perform selective sparse access to dense data. These workloads are inefficient on hardware designed for dense computation and do not map well onto sparse data representations. We build a vectorized and parallel matrix-multiplication system $A \times B = C$ that eliminates unnecessary computations and avoids branches based on a runtime evaluation of sparsity. We combine dynamic code lookup, which adapts to the specific sparsity encoded in the $B$ matrix, with preprocessing of the sparsity maps of the $A$ and $B$ matrices, so that conditional branches are computed once for the whole computation. For a wide range of sparsity, from 60% to 95% zeros, our implementation executes fewer instructions and achieves higher performance than Intel MKL's dense or sparse matrix-multiply routines. Benefits can be as large as a 2x speedup and 4x fewer instructions.
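The idea of evaluating sparsity once and then dispatching through a table can be illustrated with a small scalar sketch. This is not the authors' implementation: it is plain Python with no vectorization or parallelism, and the block width `BLOCK`, the table `KERNELS`, and all function names are hypothetical choices made for illustration. It preprocesses $B$ into per-block zero/nonzero bit patterns, records the nonzero positions of $A$ once, and then uses each pattern as an index into a precomputed table, so the inner loop never tests a value against zero.

```python
BLOCK = 4  # hypothetical block width; a real kernel would match the SIMD width

# Stand-in for a code-generated kernel table: for every zero/nonzero bit
# pattern of a BLOCK-wide slice, the offsets that actually need work.
KERNELS = [tuple(o for o in range(BLOCK) if (p >> o) & 1)
           for p in range(1 << BLOCK)]


def row_patterns(B):
    """Preprocess B once: one bit pattern per (row, column block)."""
    n = len(B[0])
    patterns = []
    for row in B:
        row_pats = []
        for start in range(0, n, BLOCK):
            p = 0
            for o, v in enumerate(row[start:start + BLOCK]):
                if v != 0:
                    p |= 1 << o
            row_pats.append(p)
        patterns.append(row_pats)
    return patterns


def nonzero_positions(A):
    """Preprocess A once: for each row, the indices k with A[i][k] != 0."""
    return [[k for k, v in enumerate(row) if v != 0] for row in A]


def masked_matmul(A, B):
    """Return (C, multiply_add_count) for C = A x B, skipping zeros in A and B."""
    m, n = len(A), len(B[0])
    pats = row_patterns(B)
    nz = nonzero_positions(A)
    C = [[0] * n for _ in range(m)]
    work = 0
    for i in range(m):
        for k in nz[i]:                    # zeros of A are never visited
            a = A[i][k]
            for jb, p in enumerate(pats[k]):
                base = jb * BLOCK
                for o in KERNELS[p]:       # table lookup replaces per-element tests
                    C[i][base + o] += a * B[k][base + o]
                    work += 1
    return C, work


def dense_matmul(A, B):
    """Reference triple loop, used only to check the masked result."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


A = [[1, 0, 2], [0, 0, 0], [0, 3, 0]]
B = [[1, 0, 0, 0, 2], [0, 4, 0, 0, 0], [0, 0, 0, 5, 0]]
C, work = masked_matmul(A, B)
print(C, work)
```

On this small input the masked version performs 4 multiply-adds, where the dense triple loop performs 45. In a vectorized setting, each table entry would presumably correspond to a specialized straight-line kernel rather than a tuple of offsets, but that mapping is an assumption of this sketch.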
