
Accelerating Transformer Inference and Training with 2:4 Activation Sparsity

Daniel Haziza, Timothy Chou, Dhruv Choudhary, Luca Wehrstedt, Francisco Massa, Jiecao Yu, Geonhwa Jeong, Supriya Rao, Patrick Labatut, Jesse Cai

TL;DR

The paper addresses the computational bottlenecks of Transformer-based LLM training and inference by applying 2:4 activation sparsity to Squared-ReLU FFNs. It exploits the intrinsic sparsity of these activations to accelerate computation without accuracy loss. It develops a practical FP8-based, three-kernel FFN path and introduces stability-focused optimizations, including a 95/5 feature split and token permutation, to make sparsity effective in the backward pass. Empirical results show minimal accuracy degradation in LLM pretraining. They also show kernel-level speedups of up to 1.3x for FFNs across the forward and backward passes, and up to ~30% for the forward pass alone, suggesting a viable route to accelerating large-scale models. The work provides a scalable framework for combining activation sparsity with hardware-accelerated 2:4 GEMMs, informing future sparse Transformer designs and larger-model deployments.

Abstract

In this paper, we demonstrate how to apply 2:4 sparsity, a popular hardware-accelerated GPU sparsity pattern, to activations in order to accelerate large language model training and inference. Crucially, we exploit the intrinsic sparsity found in Squared-ReLU activations to provide this acceleration with no accuracy loss. Our approach makes Feed-Forward Networks (FFNs) up to 1.3x faster in both the forward and backward passes. This work highlights the potential for sparsity to play a key role in accelerating large language model training and inference.
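To make the two ingredients concrete, here is a minimal NumPy sketch of a Squared-ReLU activation and of 2:4 sparsification. It assumes that sparsification keeps the two largest-magnitude values in every group of four consecutive entries. Magnitude-based selection is an assumption here; the hardware pattern only requires at most two nonzeros per group. The helper names `squared_relu`, `sparsity_level` and `sparsify_2to4` are hypothetical and are not the authors' kernels. Grouping along features gives the token-wise pattern of Figure 2b, and grouping along tokens gives the feature-wise pattern of Figure 2c. The returned count shows how many nonzeros had to be dropped.

```python
import numpy as np

def squared_relu(x: np.ndarray) -> np.ndarray:
    """Squared-ReLU activation: max(x, 0) ** 2."""
    r = np.maximum(x, 0.0)
    return r * r

def sparsity_level(a: np.ndarray) -> float:
    """Fraction of entries that are exactly zero."""
    return float(np.mean(a == 0.0))

def sparsify_2to4(a: np.ndarray, axis: int = -1):
    """Hypothetical helper: keep the 2 largest-magnitude entries in every group
    of 4 consecutive entries along `axis` and zero the other 2.

    Returns (sparse, num_dropped), where num_dropped counts nonzero entries
    that had to be zeroed. If it is 0, `sparse` equals `a` exactly.
    """
    moved = np.moveaxis(a, axis, -1)
    if moved.shape[-1] % 4 != 0:
        raise ValueError("the size along `axis` must be a multiple of 4")
    groups = moved.reshape(*moved.shape[:-1], -1, 4)
    # Positions of the 2 smallest magnitudes in each group of 4 (assumed policy).
    smallest_two = np.argsort(np.abs(groups), axis=-1)[..., :2]
    keep = np.ones(groups.shape, dtype=bool)
    np.put_along_axis(keep, smallest_two, False, axis=-1)
    sparse_groups = np.where(keep, groups, 0.0)
    num_dropped = int(np.count_nonzero(groups) - np.count_nonzero(sparse_groups))
    sparse = np.moveaxis(sparse_groups.reshape(moved.shape), -1, axis)
    return sparse, num_dropped

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic pre-activations of shape [seqlen, features]; the -1.5 shift is
    # an arbitrary stand-in for a trained layer, not data from the paper.
    a = squared_relu(rng.standard_normal((64, 256)) - 1.5)
    print(f"sparsity level: {sparsity_level(a):.1%}")
    _, dropped_tok = sparsify_2to4(a, axis=-1)  # token-wise: groups along features
    _, dropped_feat = sparsify_2to4(a, axis=0)  # feature-wise: groups along tokens
    print("values dropped (token-wise, feature-wise):", dropped_tok, dropped_feat)
```

When every group of four already holds at most two nonzeros, nothing is dropped and the sparse tensor equals the original. This is why highly sparse Squared-ReLU outputs can be sparsified with little or no error.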

Paper Structure

This paper contains 4 sections, 6 figures, 2 tables.

Figures (6)

  • Figure 1: We plot how the sparsity level of different FFN layers evolves over a training run. Replacing SwiGLU with Squared-ReLU does not impact the model accuracy (see the accuracy results table), and it makes the FFN activations highly sparse during training and inference (84-98% sparse for this model).
  • Figure 2: Given an activation tensor $A$ of shape $[seqlen, features]$ (Figure 2a), we can accelerate the computation of $AB$ if $A$ is 2:4-sparse token-wise (Figure 2b), or of $A^TB$ if $A$ is 2:4-sparse feature-wise (Figure 2c). To respect the 2:4 sparsity constraint, some values may need to be dropped when $A$ is sparsified (shown in red). The sparser the original tensor, the less likely we are to drop values; if no values are dropped, the calculation is exact.
  • Figure 3: Pseudo-code for our proposed replacement FP8 Squared-ReLU FFN.
  • Figure 4: The compute graph for a Squared-ReLU FFN during training. As $y_2 = \mathrm{ReLU}^2(y_1)$ is highly sparse, any matrix multiplication involving $y_2$ or $dy_1$ can be accelerated. In total, 4 of the 6 matrix multiplications in the forward and backward passes can be accelerated.
  • Figure 5: Our kernels can accelerate the FFN forward pass up to 30%, depending on the model dimension and batch size. Larger models or larger batch sizes lead to higher speedups. All FP8 matrix multiplications are done with row-wise scaling, to match the baseline training recipe.
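Figure 4 states that 4 of the 6 matrix multiplications can be accelerated, and writing the passes out by hand lets you check that count. The minimal NumPy sketch below assumes a bias-free FFN, $\mathrm{out} = \mathrm{ReLU}^2(x W_1) W_2$, and follows the caption's names $y_1$, $y_2$ and $dy_1$. The function name `ffn_forward_backward` and the remaining variable names are hypothetical. Since the derivative of $\mathrm{ReLU}(y)^2$ is $2\,\mathrm{ReLU}(y)$, $dy_1$ is zero wherever $y_2$ is zero. Every product in the sketch is computed densely; the comments only mark which ones could use a 2:4 sparse kernel.

```python
import numpy as np

def ffn_forward_backward(x, w1, w2, dout):
    """Hypothetical helper: forward and backward pass of out = ReLU(x @ w1)**2 @ w2.

    `dout` is dL/d(out). Returns a dict of intermediate tensors and gradients.
    """
    # Forward pass
    y1 = x @ w1          # (1) x @ W1: both operands dense
    r = np.maximum(y1, 0.0)
    y2 = r * r           # Squared-ReLU output, highly sparse
    out = y2 @ w2        # (2) y2 @ W2: sparse left operand

    # Backward pass
    dw2 = y2.T @ dout    # (3) y2^T @ dout: sparse operand y2
    dy2 = dout @ w2.T    # (4) dout @ W2^T: both operands dense
    dy1 = dy2 * 2.0 * r  # d/dy ReLU(y)^2 = 2 ReLU(y): zero wherever y2 is zero
    dx = dy1 @ w1.T      # (5) dy1 @ W1^T: sparse left operand
    dw1 = x.T @ dy1      # (6) x^T @ dy1 = (dy1^T @ x)^T: sparse operand dy1
    return {"out": out, "y2": y2, "dy1": dy1, "dx": dx, "dw1": dw1, "dw2": dw2}

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    seqlen, dim, hidden = 16, 8, 32
    x = rng.standard_normal((seqlen, dim))
    w1 = rng.standard_normal((dim, hidden)) / np.sqrt(dim)
    w2 = rng.standard_normal((hidden, dim)) / np.sqrt(hidden)
    dout = rng.standard_normal((seqlen, dim))
    g = ffn_forward_backward(x, w1, w2, dout)
    # Every nonzero of dy1 sits where y2 is nonzero, so products (2), (3),
    # (5) and (6) inherit the activation sparsity.
    print("dy1 support inside y2 support:", bool(np.all(g["y2"][g["dy1"] != 0] != 0)))
```

Products (2) and (5) reduce over features, so they need the token-wise 2:4 pattern of Figure 2b. The weight-gradient products (3) and (6) reduce over tokens, so they need the feature-wise pattern of Figure 2c.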
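Figure 5 notes that every FP8 matrix multiplication uses row-wise scaling. NumPy has no FP8 dtype, so this minimal sketch simulates one, assuming the float8 e4m3 format (3 mantissa bits, largest finite value 448). It also assumes that "row-wise" means one scale per row of the left operand and one per column of the right operand. The function names are hypothetical. Products accumulate in float64, which is more precise than a GPU kernel's accumulators, and 2:4 sparsity is not modeled.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite float8 e4m3 value (assumed format)

def round_to_e4m3(v: np.ndarray) -> np.ndarray:
    """Simulate round-to-nearest-even into float8 e4m3, saturating at +/-448.

    e4m3 keeps 3 mantissa bits; normal exponents go down to -6, below which
    values are subnormal with a fixed spacing of 2**-9.
    """
    v = np.clip(v, -E4M3_MAX, E4M3_MAX)
    mag = np.abs(v)
    _, e = np.frexp(mag)               # mag = m * 2**e with m in [0.5, 1)
    exponent = np.maximum(e - 1, -6)   # mag in [2**exponent, 2**(exponent + 1))
    spacing = np.exp2(exponent - 3.0)  # gap between neighbours with 3 mantissa bits
    return np.sign(v) * np.round(mag / spacing) * spacing  # np.round: half-to-even

def quantize_rowwise(a: np.ndarray):
    """One scale per row, chosen so each row's max magnitude maps to E4M3_MAX."""
    amax = np.max(np.abs(a), axis=1, keepdims=True)
    scale = np.where(amax > 0, amax / E4M3_MAX, 1.0)  # all-zero rows keep scale 1
    return round_to_e4m3(a / scale), scale

def fp8_rowwise_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Simulated a @ b with a scaled per row and b scaled per column."""
    qa, scale_a = quantize_rowwise(a)     # scale_a: [M, 1]
    qbt, scale_b = quantize_rowwise(b.T)  # scale_b: [N, 1], one per column of b
    # Row scales of a and column scales of b factor out of the product.
    return (qa @ qbt.T) * scale_a * scale_b.T

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a = rng.standard_normal((32, 64))
    b = rng.standard_normal((64, 16))
    exact = a @ b
    approx = fp8_rowwise_matmul(a, b)
    print("max relative error:", np.max(np.abs(approx - exact)) / np.max(np.abs(exact)))
```

Scaling each row separately keeps a single large outlier from reducing the precision of every other row, at the cost of storing one scale per row or column.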