BLaST: High Performance Inference and Pretraining using BLock Sparse Transformers

Patrik Okanovic, Sameer Deshmukh, Grzegorz Kwasniewski, Yi Zhu, Haruto Fujii, Sakina Fatima, Maciej Besta, Kentaro Katayama, Takumi Honda, Yusuke Nagasaka, Torsten Hoefler

TL;DR

BLaST reduces memory movement and hardware cost in Transformer-based models by pruning MLP weights into a block-sparse pattern. It combines two pruning strategies (Blocked Prune-and-Grow, or P&G, and oLLM) with a high-performance block-sparse SpMM kernel written in Triton, which lets it reach high sparsity while preserving accuracy. Across eight datasets and eleven models, BLaST achieves up to $95\%$ sparsity with minimal accuracy loss, delivering up to $2.2\times$ end-to-end speedups and up to $4.45\times$ memory-footprint reductions, with substantial gains on Llama 3.2 and larger models. It is viable for both inference and pretraining at scale.
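The block-sparse SpMM at the core of BLaST can be pinned down with a small NumPy reference. This is a correctness sketch, not the Triton kernel: it assumes square b×b blocks, the blocked Compressed Sparse Column (BCSC) layout named in Figure 4, and the orientation Y = X·W. The names `dense_to_bcsc` and `bspmm` are hypothetical.

```python
import numpy as np


def dense_to_bcsc(W, b):
    """Pack a dense K x N matrix into a blocked CSC (BCSC) layout with b x b blocks.

    Block column j owns blocks[col_ptr[j]:col_ptr[j + 1]], whose block-row
    indices are row_idx[col_ptr[j]:col_ptr[j + 1]]. All-zero blocks are dropped.
    """
    K, N = W.shape
    assert K % b == 0 and N % b == 0, "both dimensions must be multiples of b"
    col_ptr, row_idx, blocks = [0], [], []
    for j in range(N // b):
        for i in range(K // b):
            blk = W[i * b:(i + 1) * b, j * b:(j + 1) * b]
            if np.any(blk):
                row_idx.append(i)
                blocks.append(blk.copy())
        col_ptr.append(len(row_idx))
    packed = np.stack(blocks) if blocks else np.zeros((0, b, b), dtype=W.dtype)
    return np.array(col_ptr), np.array(row_idx, dtype=np.int64), packed


def bspmm(X, col_ptr, row_idx, blocks, b, N):
    """Y = X @ W for an M x K activation X and a BCSC-packed K x N weight W."""
    M = X.shape[0]
    Y = np.zeros((M, N), dtype=np.result_type(X, blocks))
    for j in range(len(col_ptr) - 1):
        # Each block column of W produces one b-wide column slab of Y;
        # only stored (nonzero) blocks contribute, so pruned blocks cost nothing.
        for p in range(col_ptr[j], col_ptr[j + 1]):
            i = row_idx[p]
            Y[:, j * b:(j + 1) * b] += X[:, i * b:(i + 1) * b] @ blocks[p]
    return Y


# Usage: zero out roughly 70% of the 4 x 4 blocks, then compare with the dense product.
rng = np.random.default_rng(0)
b = 4
W = rng.standard_normal((16, 12))
keep = rng.random((16 // b, 12 // b)) < 0.3
W = W * np.repeat(np.repeat(keep, b, axis=0), b, axis=1)
col_ptr, row_idx, blocks = dense_to_bcsc(W, b)
X = rng.standard_normal((5, 16))
assert np.allclose(bspmm(X, col_ptr, row_idx, blocks, b, W.shape[1]), X @ W)
```

A production kernel would parallelize over output tiles and block columns instead of looping in Python; the sketch only shows that the work scales with the number of stored blocks rather than with K × N.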

Abstract

The energy consumption of large-scale ML models is dominated by data movement: shuffling billions of parameters across memory hierarchies and data centers. Sparsification offers a principled way to mitigate these costs by pruning redundant weights and activations, thereby reducing data movement. Yet effective sparsification remains challenging: existing methods incur significant accuracy degradation, performance overhead, or both. We introduce (Bl)ock (a)nd (S)parse (T)ransformers (BLaST), a general, robust, and reliable sparsification method applicable to linear layers in all settings. Our method iteratively sparsifies weight matrices into a block sparsity pattern suitable for efficient sparse matrix-matrix multiplication (SpMM). BLaST achieves up to 95% sparsity in MLP weights with negligible accuracy loss (below 2.25% in the majority of cases). We show a 2.2x inference speedup for Llama 3.2 with 16 GPUs, and up to a 4.45x reduction in inference memory footprint, resulting in a 2.9x reduction in GPU setup and operating costs.
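The iterative block sparsification described above can be illustrated with a generic magnitude criterion: score each b×b block by its Frobenius norm, keep the top fraction, and raise the target sparsity over several rounds. This is a hedged sketch, not P&G or oLLM; the norm-based score, the cubic ramp in `sparsity_schedule`, and all function names are assumptions for illustration.

```python
import numpy as np


def block_scores(W, b):
    """Frobenius norm of every b x b block of W, shape (K // b, N // b)."""
    K, N = W.shape
    return np.sqrt((W.reshape(K // b, b, N // b, b) ** 2).sum(axis=(1, 3)))


def block_mask(W, b, sparsity):
    """Elementwise boolean mask that zeroes the lowest-scoring `sparsity` fraction of blocks."""
    scores = block_scores(W, b)
    n_blocks = scores.size
    n_keep = n_blocks - int(round(sparsity * n_blocks))
    keep = np.zeros(n_blocks, dtype=bool)
    if n_keep > 0:
        keep[np.argsort(scores, axis=None)[-n_keep:]] = True
    keep = keep.reshape(scores.shape)
    # Expand each block-level keep/prune decision to all b x b elements of the block.
    return np.repeat(np.repeat(keep, b, axis=0), b, axis=1)


def sparsity_schedule(step, total_steps, final_sparsity):
    """Cubic ramp from 0 to final_sparsity (an assumed schedule, not BLaST's)."""
    t = min(step / total_steps, 1.0)
    return final_sparsity * (1.0 - (1.0 - t) ** 3)


# Usage: raise block sparsity to 95% over 10 pruning rounds.
rng = np.random.default_rng(0)
W = rng.standard_normal((64, 40))  # 8 x 5 = 40 blocks of 8 x 8
b, total = 8, 10
for step in range(1, total + 1):
    # A real run would interleave dense optimizer updates here.
    W = W * block_mask(W, b, sparsity_schedule(step, total, 0.95))
print((W == 0).mean())  # fraction of zeroed weights
```

Scoring whole blocks rather than single weights is what makes the result usable by a block SpMM: every surviving block is dense, so the kernel never handles scattered nonzeros.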

Paper Structure

This paper contains 31 sections, 7 equations, 12 figures, 5 tables.

Figures (12)

  • Figure 1: BLaST sparsifies MLP weights up to 95% with minimal accuracy loss and speeds up end-to-end Llama 3.2-1B inference by up to 1.6$\times$. For Llama 3.2-405B, sparsity reduces the number of required GPUs by up to 2.9$\times$.
  • Figure 2: Overview of BLaST. Training or fine-tuning alternates dense updates with a pluggable block-pruning stage (instantiated with Blocked Prune-and-Grow and oLLM) while using a fused, high-performance sparse MLP kernel until the target sparsity is reached.
  • Figure 3: The blocked oBERT accumulates weight gradients $G_i$ over many forward and backward iterations to generate the sensitivity scores $S_i$ using the Fisher inverse matrix. $S_i$ is then used for generating a blocked mask $M_i$. The blocked weight matrix $W_{new}$ is generated with an element-wise multiplication of the weight matrix $W_i$ with $M_i$.
  • Figure 4: Overview of BLaST BSpMM. The kernel uses a blocked Compressed Sparse Column (BCSC) layout derived from sparsified neural network weights. BLaST applies block-level, bottom-up 2D parallelism in Triton to maximize GPU utilization.
  • Figure 5: Throughput of BLaST BSpMM compared to state-of-the-art kernels for various combinations of batch size, embedding dimension, and sequence length typical of the networks we benchmark. BLaST BSpMM can outperform all vendor implementations by a wide margin for all block sizes.
  • ...and 7 more figures
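The pipeline in Figure 3 (gradients $G_i$, then sensitivity scores $S_i$, then blocked mask $M_i$, then $W_{new} = W_i \odot M_i$) can be sketched in NumPy. The sketch makes a simplifying assumption that the figure does not: it keeps only the diagonal of the empirical Fisher matrix, so the inverse is trivial and a block's score reduces to $\frac{1}{2}\sum_{j \in Q} F_{jj} w_j^2$. The blocked oBERT method in the figure uses the Fisher inverse matrix, which captures correlations between weights that this diagonal version ignores. The damping constant and the function names are hypothetical.

```python
import numpy as np


def fisher_block_scores(W, grads, b, damp=1e-4):
    """Per-block saliency under a diagonal empirical-Fisher approximation.

    grads has shape (m, K, N): m accumulated gradient samples of W.
    With F_jj = damp + mean(g_j ** 2), the OBS-style cost of zeroing
    block Q is 0.5 * sum_{j in Q} F_jj * w_j ** 2.
    """
    K, N = W.shape
    fisher_diag = damp + np.mean(grads ** 2, axis=0)  # (K, N)
    per_weight = 0.5 * fisher_diag * W ** 2
    return per_weight.reshape(K // b, b, N // b, b).sum(axis=(1, 3))


def mask_from_scores(S, b, sparsity):
    """Blocked boolean mask that keeps the highest-saliency blocks."""
    n = S.size
    n_keep = n - int(round(sparsity * n))
    keep = np.zeros(n, dtype=bool)
    if n_keep > 0:
        keep[np.argsort(S, axis=None)[-n_keep:]] = True
    keep = keep.reshape(S.shape)
    return np.repeat(np.repeat(keep, b, axis=0), b, axis=1)


# Usage with synthetic data standing in for recorded gradients G_i.
rng = np.random.default_rng(0)
W_i = rng.standard_normal((32, 32))
G = rng.standard_normal((16, 32, 32))
S_i = fisher_block_scores(W_i, G, b=8)
M_i = mask_from_scores(S_i, b=8, sparsity=0.75)
W_new = W_i * M_i  # element-wise product, as in Figure 3
```

Compared with plain magnitude pruning, the gradient term lets a block with large weights still be pruned when the loss is insensitive to it, and protects a block with small weights when the loss is sensitive to it.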