Enhancing One-shot Pruned Pre-trained Language Models through Sparse-Dense-Sparse Mechanism

Guanchen Li, Xiandong Zhao, Lian Liu, Zeping Li, Dong Li, Lu Tian, Jie He, Ashish Sirasao, Emad Barsoum

TL;DR

The paper tackles the efficiency gap of large pre-trained language models by addressing the performance drop seen in one-shot pruning on compact models. It introduces Sparse-Dense-Sparse (SDS), a three-step framework (initial pruning, re-dense weight reconstruction with sparse regularization, and a second pruning) that optimizes weight distributions to produce pruning-friendly dense predecessors for subsequent pruning. Through layer-wise knowledge distillation and carefully designed regularizations, SDS achieves superior perplexity and zero-shot accuracy compared to SparseGPT and Wanda under the same sparsity, and delivers notable CPU speedups on commodity hardware. The approach demonstrates that reconfiguring weight distributions rather than relying solely on pruning masks can yield robust, hardware-friendly sparse models suitable for real-world deployment.

Abstract

Pre-trained language models (PLMs) are engineered to be robust in contextual understanding and exhibit outstanding performance in various natural language processing tasks. However, their considerable size incurs significant computational and storage costs. Modern pruning strategies employ one-shot techniques to compress PLMs without retraining on task-specific or general data; however, these approaches often lead to an unavoidable reduction in performance. In this paper, we propose SDS, a Sparse-Dense-Sparse pruning framework that enhances the performance of pruned PLMs from a weight distribution optimization perspective. We outline the pruning process in three steps. Initially, we prune less critical connections in the model using conventional one-shot pruning methods. Next, we reconstruct a dense model featuring a pruning-friendly weight distribution by reactivating pruned connections with sparse regularization. Finally, we perform a second pruning round, yielding a pruned model superior to the one obtained from the initial pruning. Experimental results demonstrate that SDS outperforms the state-of-the-art pruning techniques SparseGPT and Wanda under an identical sparsity configuration. For instance, SDS reduces perplexity by 9.13 on Raw-WikiText2 and improves accuracy by an average of 2.05% across multiple zero-shot benchmarks for OPT-125M with 2:4 sparsity.
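The three steps described in the abstract can be illustrated on a single linear layer. The following is a minimal sketch under stated assumptions, not the authors' implementation: magnitude-based 2:4 pruning stands in for SparseGPT or Wanda, a plain L1 penalty stands in for the paper's sparse regularization, and matching the dense layer's outputs with a squared error stands in for layer-wise knowledge distillation. The function names (`prune_2_4`, `redense`, `sds`) and the hyperparameters (`lam`, `lr`, `steps`) are hypothetical.

```python
import numpy as np


def prune_2_4(w):
    """Zero the 2 smallest-magnitude entries in every group of 4 along each row.

    Assumption: plain magnitude pruning, used here only as a stand-in
    for the one-shot pruners named in the paper.
    """
    rows, cols = w.shape
    assert cols % 4 == 0, "2:4 sparsity needs a column count divisible by 4"
    groups = w.reshape(rows, cols // 4, 4)
    # Indices of the two smallest |w| inside each group of four.
    drop = np.argsort(np.abs(groups), axis=-1)[..., :2]
    mask = np.ones_like(groups, dtype=bool)
    np.put_along_axis(mask, drop, False, axis=-1)
    return (groups * mask).reshape(rows, cols)


def redense(w_sparse, w_teacher, x, lam=1e-3, lr=1e-2, steps=200):
    """Re-dense step: every entry is trainable again, including pruned ones.

    Assumption: gradient descent on a squared output error against the
    dense teacher layer plus an L1 penalty on the weights.
    """
    w = w_sparse.copy()
    target = x @ w_teacher.T              # outputs of the original dense layer
    n = x.shape[0]
    for _ in range(steps):
        err = x @ w.T - target            # shape: (n, out_features)
        grad = err.T @ x / n + lam * np.sign(w)
        w -= lr * grad
    return w


def sds(w_dense, x, **kwargs):
    """Sparse -> dense -> sparse on one weight matrix."""
    w_s1 = prune_2_4(w_dense)                     # step 1: initial pruning
    w_d = redense(w_s1, w_dense, x, **kwargs)     # step 2: re-dense reconstruction
    w_s2 = prune_2_4(w_d)                         # step 3: second pruning
    return w_s1, w_d, w_s2


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w_dense = rng.normal(size=(8, 16))    # toy layer: 16 inputs, 8 outputs
    x = rng.normal(size=(64, 16))         # toy calibration activations
    w_s1, w_d, w_s2 = sds(w_dense, x)
    for name, w in [("first pruning", w_s1), ("second pruning", w_s2)]:
        mse = np.mean((x @ w.T - x @ w_dense.T) ** 2)
        print(f"{name}: output MSE vs. dense = {mse:.4f}")
```

Running the script prints the output error of the first and second pruned layers against the dense layer on random toy data. It is meant only to make the data flow of the three steps concrete; it says nothing about the perplexity or accuracy results reported in the paper.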

Paper Structure

This paper contains 17 sections, 4 equations, 4 figures, and 8 tables.

Figures (4)

  • Figure 1: An Overview of the Steps of the SDS Framework, divided into initial pruning, re-dense weight reconstruction, and a second round of pruning. The upper figure shows how the weight distribution changes within the SDS framework, and the lower figure shows how the weight connections change. The weights are extracted from the FFN in the 12th transformer block of OPT-125M, with a 50% sparsity configuration. Initially, the dense weights follow a Gaussian distribution. After pruning by SparseGPT, a concentrated, bimodal distribution emerges (zero values are omitted in sparse weight distributions for better clarity). After connection reconstruction with sparse regularization, a three-peaked distribution emerges. Finally, the second pruning round attenuates the sharp peaks, resulting in a softer bimodal distribution. Perplexity (PPL) is evaluated on Raw-WikiText2. The second pruned model achieves a lower perplexity than the initially pruned one.
  • Figure 2: Four Data Selection Paradigms in Weight Adjustment. Straight lines represent forward propagation, and dashed lines represent knowledge distillation.
  • Figure 3: Changes in Distributions During Optimization of Pruned PLMs. The distribution observations are from the last layer of OPT-125M with 50% pruning. (a) represents the process of first pruning the model by magnitude (absmin) and then optimizing the pruned model using SD-data. (b) represents the SDS w KD in the ablation study. (c) represents the SDS. (d) represents the SDS w KD. (e) represents the SDS w SD. (f) represents the SDS w SD and with Wanda as the pruning method. Zero values are omitted in sparse weight distributions for better clarity.
  • Figure 4: Sparsity vs. Perplexity in OPTs.