Enhancing One-shot Pruned Pre-trained Language Models through Sparse-Dense-Sparse Mechanism
Guanchen Li, Xiandong Zhao, Lian Liu, Zeping Li, Dong Li, Lu Tian, Jie He, Ashish Sirasao, Emad Barsoum
TL;DR
The paper targets the efficiency gap of large pre-trained language models, specifically the performance drop that one-shot pruning causes in compact models. It introduces Sparse-Dense-Sparse (SDS), a three-step framework (initial pruning, re-dense weight reconstruction with sparse regularization, and a second pruning). The framework reshapes the weight distribution so that the reconstructed dense model is easier to prune in the final round. Using layer-wise knowledge distillation and carefully designed regularizations, SDS achieves lower perplexity and higher zero-shot accuracy than SparseGPT and Wanda at the same sparsity, and delivers notable CPU speedups on commodity hardware. The approach shows that reshaping weight distributions, rather than relying solely on better pruning masks, can yield robust, hardware-friendly sparse models suitable for real-world deployment.
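To make the three-step data flow concrete, here is a minimal NumPy sketch on a single linear layer. It is not the authors' implementation, and it rests on several assumptions. Magnitude pruning stands in for SparseGPT/Wanda. The "sparse regularization" is assumed to be an L1 penalty optimized with proximal gradient descent (ISTA). The layer-wise distillation target is assumed to be the original dense layer's output on a small calibration batch. All names (`magnitude_prune`, `recon_error`, `redense`, `sds`) are hypothetical.

```python
import numpy as np

def magnitude_prune(W, sparsity):
    """Zero the `sparsity` fraction of entries with the smallest |w|.
    Stand-in for SparseGPT/Wanda, which use other importance scores."""
    flat = np.abs(W).ravel()
    k = int(flat.size * sparsity)
    mask = np.ones(flat.size, dtype=bool)
    mask[np.argsort(flat)[:k]] = False
    return W * mask.reshape(W.shape)

def recon_error(W, X, Y):
    """Layer-wise distillation loss: mean squared distance to teacher outputs Y."""
    return np.sum((X @ W.T - Y) ** 2) / X.shape[0]

def redense(W_init, X, Y, lam=1e-3, iters=200):
    """Step 2 (assumed form): let every weight become nonzero again and minimize
    recon_error(W) + lam * ||W||_1 with proximal gradient descent (ISTA)."""
    n = X.shape[0]
    # Step size = 1 / Lipschitz constant of the gradient of recon_error
    step = 1.0 / (2.0 / n * np.linalg.norm(X, 2) ** 2)
    W = W_init.copy()
    for _ in range(iters):
        grad = 2.0 / n * (X @ W.T - Y).T @ X
        Z = W - step * grad
        # Soft-thresholding: proximal operator of the L1 penalty
        W = np.sign(Z) * np.maximum(np.abs(Z) - step * lam, 0.0)
    return W

def sds(W_dense, X, sparsity=0.5, lam=1e-3, iters=200):
    Y = X @ W_dense.T                        # teacher: original dense layer outputs
    W1 = magnitude_prune(W_dense, sparsity)  # sparse: first one-shot pruning
    W2 = redense(W1, X, Y, lam, iters)       # dense: regularized reconstruction
    W3 = magnitude_prune(W2, sparsity)       # sparse: second pruning round
    return W1, W2, W3

rng = np.random.default_rng(0)
W = rng.standard_normal((16, 32))
X = rng.standard_normal((256, 32))  # hypothetical calibration inputs
W1, W2, W3 = sds(W, X)
Y = X @ W.T
print(recon_error(W1, X, Y), recon_error(W2, X, Y), recon_error(W3, X, Y))
```

In this toy setting the dense teacher output is exactly reachable, so the re-dense step mostly moves the weights back toward the original layer. The sketch therefore shows only the Sparse-Dense-Sparse structure and is not expected to reproduce the gains reported for the full method.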
Abstract
Pre-trained language models (PLMs) are engineered for robust contextual understanding and achieve outstanding performance across a wide range of natural language processing tasks. However, their considerable size incurs significant computational and storage costs. Modern pruning strategies use one-shot techniques to compress PLMs without retraining on task-specific or general data; however, these approaches often cause a non-negligible drop in performance. In this paper, we propose SDS, a Sparse-Dense-Sparse pruning framework that improves the performance of pruned PLMs by optimizing their weight distribution. The pruning process has three steps. First, we prune less critical connections using a conventional one-shot pruning method. Next, we reconstruct a dense model with a pruning-friendly weight distribution by reactivating the pruned connections under sparse regularization. Finally, we perform a second round of pruning, which yields a better pruned model than the first round. Experimental results demonstrate that SDS outperforms the state-of-the-art pruning techniques SparseGPT and Wanda under an identical sparsity configuration. For instance, for OPT-125M with 2:4 sparsity, SDS reduces perplexity by 9.13 on Raw-Wikitext2 and improves accuracy by an average of 2.05% across multiple zero-shot benchmarks.
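2:4 sparsity, the pattern used in the OPT-125M example, keeps at most 2 nonzero weights in every group of 4 consecutive weights. Here the groups are assumed to run along each row (the input dimension of a linear layer). The minimal NumPy sketch below builds such a pattern by magnitude. SparseGPT and Wanda choose the surviving weights with their own importance scores, so `prune_2_4` (a hypothetical helper) illustrates only the sparsity pattern, not either method.

```python
import numpy as np

def prune_2_4(W):
    """Keep the 2 largest-|w| entries in every group of 4 consecutive
    weights along each row and zero the other 2 (2:4 sparsity)."""
    rows, cols = W.shape
    if cols % 4 != 0:
        raise ValueError("number of columns must be a multiple of 4")
    groups = np.abs(W).reshape(rows, cols // 4, 4)
    drop = np.argsort(groups, axis=-1)[..., :2]  # 2 smallest magnitudes per group
    mask = np.ones(groups.shape, dtype=bool)
    np.put_along_axis(mask, drop, False, axis=-1)
    return W * mask.reshape(rows, cols)

W = np.array([[0.1, -0.9, 0.3, 0.05, 2.0, -1.5, 0.2, 0.0]])
print(prune_2_4(W))
```

For this row, the first group keeps -0.9 and 0.3 and the second keeps 2.0 and -1.5, so exactly half of the weights survive.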
