S-STE: Continuous Pruning Function for Efficient 2:4 Sparse Pre-training
Yuezhou Hu, Jun Zhu, Jianfei Chen
TL;DR
S-STE introduces a continuous pruning function for $2:4$ sparsity (the $N{:}M$ pattern with $N=2$, $M=4$) to enable efficient pre-training of large transformers. It combines a $2:4$-specific soft-thresholding operator with a fixed per-tensor scaling factor, and pairs these with minimum-variance unbiased estimation (MVUE) for activation gradients and FP8 quantization of the training process. This design avoids the discontinuities that derail STE-based methods and achieves competitive accuracy across translation, vision, and language modeling tasks. The work supports these results with ablations and reports measured speedups on hardware, while noting that the achievable acceleration depends on GPU support for $2:4$ sparsity, as on Nvidia Ampere and Hopper GPUs. Overall, S-STE provides a principled path to scalable sparse pre-training with improved optimization stability and hardware compatibility.
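To make the two core components concrete, here is a minimal NumPy sketch, not the toolkit's implementation. It assumes weights are grouped into runs of four consecutive elements in row-major order, and that the soft-thresholding threshold is the second-smallest magnitude in each group, so at least two entries per group become exactly zero while the rest are shrunk continuously. The per-tensor scale in `fixed_scale` is a hypothetical least-squares choice computed once from the initial dense weights and then frozen; both function names are illustrative.

```python
import numpy as np


def soft_threshold_2to4(w: np.ndarray) -> np.ndarray:
    """Continuous 2:4 projection (sketch).

    Assumes w.size is divisible by 4 and that groups are runs of four
    consecutive elements in row-major order. Each group is shrunk toward
    zero by its second-smallest magnitude, so at least two entries per
    group become exactly zero and the map is continuous in w.
    """
    groups = w.reshape(-1, 4)
    mags = np.abs(groups)
    # Threshold = second-smallest magnitude per group, kept as a column for broadcasting.
    t = np.sort(mags, axis=1)[:, 1:2]
    out = np.sign(groups) * np.maximum(mags - t, 0.0)
    return out.reshape(w.shape)


def fixed_scale(w0: np.ndarray) -> float:
    """Hypothetical per-tensor scale: least-squares beta minimizing ||w0 - beta * S(w0)||^2.

    Computed once (e.g. from the initial dense weights) and then frozen.
    """
    s = soft_threshold_2to4(w0)
    denom = float(np.sum(s * s))
    return float(np.sum(w0 * s)) / denom if denom > 0.0 else 1.0


# Usage: compute beta once, then reuse it at every step for the sparse forward weight.
rng = np.random.default_rng(0)
w = rng.standard_normal((8, 16))
beta = fixed_scale(w)                      # frozen after this point
w_sparse = beta * soft_threshold_2to4(w)   # 2:4 sparse weight used in the forward pass
```

Because every surviving entry is shrunk by the threshold, this least-squares factor is always at least 1, so under these assumptions it partly restores the magnitude lost to shrinkage.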
Abstract
Training deep neural networks (DNNs) is costly. Fortunately, Nvidia Ampere and Hopper GPUs support 2:4 sparsity, which lets them perform matrix multiplications up to twice as fast as the dense equivalent. However, previous STE-based 2:4 pre-training methods (e.g., STE with hard-thresholding and SR-STE) suffer from optimization difficulties because their pruning functions are discontinuous. In this study, we comprehensively analyse the bottleneck of traditional N:M sparse training and identify three drawbacks caused by discontinuity: incorrect descent direction, inability to predict the amount of descent, and sparse mask oscillation. In light of this, we propose S-STE, a simple yet powerful 2:4 training method with two components: continuously projecting weights to be 2:4 sparse, and rescaling the sparse weights with a fixed per-tensor scaling factor. In addition, we adopt minimum-variance unbiased estimation for activation gradients and FP8 quantization for the whole training process. Results show that our method surpasses previous 2:4 pre-training recipes and is even comparable to full-parameter models. Our toolkit is available at https://github.com/huyz2023/2by4-pretrain.
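The discontinuity argument can be checked numerically. The sketch below, assuming groups of four consecutive weights, compares hard-thresholding (keep the two largest magnitudes, as in plain STE) with a soft-thresholding operator of the kind S-STE describes, here assumed to subtract the second-smallest magnitude in each group. In the hypothetical test group, two weights swap magnitude order as the perturbation `d` crosses zero; `probe` and `jump` are illustrative helpers, not part of the toolkit.

```python
import numpy as np


def hard_threshold_2to4(w: np.ndarray) -> np.ndarray:
    """Keep the two largest-magnitude entries in each group of four (STE-style hard pruning)."""
    groups = w.reshape(-1, 4)
    # argsort is ascending, so the last two indices are the two largest magnitudes.
    keep = np.argsort(np.abs(groups), axis=1)[:, 2:]
    mask = np.zeros(groups.shape, dtype=bool)
    np.put_along_axis(mask, keep, True, axis=1)
    return np.where(mask, groups, 0.0).reshape(w.shape)


def soft_threshold_2to4(w: np.ndarray) -> np.ndarray:
    """Assumed soft-thresholding: subtract the second-smallest magnitude in each group of four."""
    groups = w.reshape(-1, 4)
    mags = np.abs(groups)
    t = np.sort(mags, axis=1)[:, 1:2]
    return (np.sign(groups) * np.maximum(mags - t, 0.0)).reshape(w.shape)


def probe(prune, d: float) -> np.ndarray:
    # Hypothetical group whose middle two weights swap magnitude order as d crosses 0.
    return prune(np.array([1.0, 0.5 + d, 0.5 - d, 0.1]))


def jump(prune, d: float = 1e-6) -> float:
    """Largest change in the pruned output between +d and -d."""
    return float(np.max(np.abs(probe(prune, d) - probe(prune, -d))))


print("hard:", jump(hard_threshold_2to4))  # about 0.5: the mask flips and the output jumps
print("soft:", jump(soft_threshold_2to4))  # about 2e-6: the output moves by O(d)
```

An arbitrarily small weight change flips the hard mask and moves the output by roughly 0.5, which is the kind of jump behind the mask oscillation and unpredictable descent described above, while the soft-thresholded output changes only in proportion to `d`.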
