MaskPrune: Mask-based LLM Pruning for Layer-wise Uniform Structures
Jiayu Qin, Jianchao Tan, Kefeng Zhang, Xunliang Cai, Wei Wang
TL;DR
MaskPrune addresses the challenge of pruning large language models while preserving a uniform inter-layer structure that benefits deployment and continued training. It formulates pruning as a minimax optimization problem that jointly learns pruning masks $m$ and target dimensions $s$ under sparsity and resource constraints, using proximal updates for the non-differentiable sparsity term and a straight-through estimator for gradient flow. The method integrates LoRA adapters and distillation losses to mitigate accuracy loss during pruning, yielding a final model with uniform structure and competitive zero-shot and perplexity performance across LLaMA-family models. Empirical results show MaskPrune often surpasses state-of-the-art baselines, especially at higher sparsities, while maintaining layer uniformity that facilitates efficient inference and training continuation.
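As a rough illustration of the two optimization devices named above, the sketch below applies a gradient step to the smooth part of an objective and a proximal step to an L1 sparsity term, with a straight-through estimator (STE) letting gradients reach continuous mask scores through a hard 0/1 mask. It is a minimal sketch under stated assumptions, not MaskPrune's implementation: the toy loss, the box constraint $m \in [0, 1]$, the 0.5 binarization threshold, and all function names are hypothetical, and the minimax treatment of the resource constraint and of the target dimensions $s$ is omitted.

```python
# Minimal sketch, NOT the authors' implementation: one proximal-gradient step on
# continuous mask scores in [0, 1], with a straight-through estimator (STE) for
# the hard 0/1 mask used in the forward pass. The toy loss, the 0.5 threshold,
# and all function names are hypothetical.

def binarize(scores, threshold=0.5):
    """STE forward pass: hard 0/1 mask from continuous scores."""
    return [1.0 if s > threshold else 0.0 for s in scores]


def prox_l1_box(v, tau):
    """Proximal operator of tau*|s| restricted to s in [0, 1].

    On [0, 1], |s| = s, so the minimizer of 0.5*(s - v)^2 + tau*s is
    v - tau clipped to the box.
    """
    return min(max(v - tau, 0.0), 1.0)


def ste_prox_step(scores, weights, targets, lr, lam):
    """One update of the mask scores on a hypothetical toy loss.

    Toy loss: 0.5 * sum_i (mask_i * w_i - t_i)^2, where mask = binarize(scores).
    The STE treats d(mask_i)/d(score_i) as 1, so the gradient w.r.t. a score is
    the gradient w.r.t. its hard mask entry. The smooth loss gets an explicit
    gradient step; the non-differentiable L1 term lam*|s| is handled by the prox.
    """
    mask = binarize(scores)
    new_scores = []
    for s, m, w, t in zip(scores, mask, weights, targets):
        grad_mask = (m * w - t) * w      # d(toy loss)/d(mask_i)
        grad_score = grad_mask           # straight-through: identity backward
        v = s - lr * grad_score          # gradient step on the smooth part
        new_scores.append(prox_l1_box(v, lr * lam))  # prox step on lam*|s|
    return new_scores
```

For example, starting from `scores = [0.9, 0.6, 0.2]` with `weights = [1.0, 0.1, 1.5]`, `targets = [1.0, 0.0, 1.5]`, `lr=0.1`, and `lam=0.1`, ten calls to `ste_prox_step` give the mask `[1.0, 0.0, 1.0]`: the prox steadily shrinks the score of the unit the toy loss does not need until it drops below the threshold, while the straight-through gradient revives the initially masked unit whose output the loss requires.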
Abstract
The remarkable performance of large language models (LLMs) across a wide range of language tasks has attracted considerable attention. However, the ever-increasing size of these models makes deployment and inference increasingly difficult. Structured pruning, an effective model compression technique, has become popular because it directly improves inference efficiency. Nevertheless, most previous optimization-based structured pruning methods give up a uniform structure across layers in exchange for the flexibility needed to maintain performance. The resulting heterogeneous structure hinders the use of off-the-shelf inference acceleration techniques and complicates configuration for continued training. To address this issue, we propose a novel mask learning paradigm based on minimax optimization that obtains a uniform pruned structure by optimizing the masks under sparsity regularization. Extensive experiments show that our method maintains high performance while ensuring a uniform pruned model structure, and that it outperforms existing state-of-the-art (SOTA) methods.
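To make the uniformity requirement concrete, the sketch below contrasts a global top-k selection over all layers, which can leave layers with very different widths, with a per-layer top-$s$ selection in which every layer keeps the same number $s$ of units. The per-unit importance scores, the notion of a "unit" (for instance an attention head or an FFN channel), and both selection rules are illustrative assumptions; this is not MaskPrune's procedure, which learns its masks through minimax optimization under sparsity regularization.

```python
# Illustrative sketch only (hypothetical helpers): how the choice of selection
# rule determines whether pruned layers end up with uniform widths.

def global_topk_masks(layer_scores, keep_total):
    """Keep the keep_total highest-scoring units across ALL layers.

    Widths can differ from layer to layer (heterogeneous structure).
    """
    ranked = sorted(
        ((score, li, ui)
         for li, layer in enumerate(layer_scores)
         for ui, score in enumerate(layer)),
        reverse=True,
    )
    masks = [[0] * len(layer) for layer in layer_scores]
    for _, li, ui in ranked[:keep_total]:
        masks[li][ui] = 1
    return masks


def uniform_topk_masks(layer_scores, s):
    """Keep the s highest-scoring units in EACH layer (uniform structure)."""
    masks = []
    for layer in layer_scores:
        order = sorted(range(len(layer)), key=lambda i: layer[i], reverse=True)
        keep = set(order[:s])
        masks.append([1 if i in keep else 0 for i in range(len(layer))])
    return masks


def widths(masks):
    """Number of units kept per layer."""
    return [sum(mask) for mask in masks]
```

With `layer_scores = [[0.9, 0.8, 0.7, 0.1], [0.3, 0.2, 0.6, 0.05], [0.95, 0.4, 0.85, 0.75]]` at 50% sparsity, `global_topk_masks(layer_scores, 6)` yields widths `[3, 0, 3]` and removes the middle layer entirely, whereas `uniform_topk_masks(layer_scores, 2)` yields `[2, 2, 2]`, so a single layer configuration describes the whole pruned model.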
