MaskPrune: Mask-based LLM Pruning for Layer-wise Uniform Structures

Jiayu Qin, Jianchao Tan, Kefeng Zhang, Xunliang Cai, Wei Wang

TL;DR

MaskPrune addresses the challenge of pruning large language models while preserving a uniform inter-layer structure that benefits deployment and continued training. It formulates pruning as a minimax optimization problem that jointly learns pruning masks $m$ and target dimensions $s$ under sparsity and resource constraints, using proximal updates for the non-differentiable sparsity term and a straight-through estimator for gradient flow. The method integrates LoRA adapters and distillation losses to mitigate accuracy loss during pruning, yielding a final model with uniform structure and competitive zero-shot and perplexity performance across LLaMA-family models. Empirical results show that MaskPrune often surpasses state-of-the-art baselines, especially at higher sparsities, while maintaining the layer uniformity that facilitates efficient inference and continued training.
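The TL;DR frames pruning as a minimax problem under resource constraints. A common generic way to express such a constraint is a Lagrangian: the primal variables are minimized while a non-negative multiplier is maximized, so the multiplier grows whenever the budget is violated. The toy below sketches that pattern with simultaneous gradient descent-ascent on a three-entry soft mask. The quadratic loss, the budget, the step size, and the name `descent_ascent` are illustrative assumptions; this is not MaskPrune's actual objective.

```python
# Hypothetical toy of a budget-constrained minimax problem (not MaskPrune's objective):
#   min_m  max_{lam >= 0}  sum_i (m_i - a_i)^2 + lam * (sum_i m_i - budget)
# The inner max over lam enforces the resource constraint sum_i m_i <= budget.


def descent_ascent(a, budget, lr=0.05, steps=2000):
    """Simultaneous gradient descent on m and projected gradient ascent on lam."""
    m = [0.0] * len(a)  # primal variables, e.g. soft mask values
    lam = 0.0  # Lagrange multiplier for the budget constraint
    for _ in range(steps):
        violation = sum(m) - budget  # > 0 means the budget is exceeded
        # Descent on m: gradient of the Lagrangian w.r.t. each m_i is 2(m_i - a_i) + lam.
        m_next = [mi - lr * (2.0 * (mi - ai) + lam) for mi, ai in zip(m, a)]
        # Ascent on lam, projected back onto lam >= 0.
        lam = max(0.0, lam + lr * violation)
        m = m_next
    return m, lam


if __name__ == "__main__":
    m, lam = descent_ascent(a=[3.0, 2.0, 1.0], budget=3.0)
    print(m, lam)
```

Stationarity gives $m_i = a_i - \lambda/2$, and the active budget $\sum_i m_i = 3$ then gives $\lambda = 2$, so the printed values should approach $[2, 1, 0]$ and $2$. Without the ascent step, $m$ would simply converge to $a$ and exceed the budget.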

Abstract

The remarkable performance of large language models (LLMs) across a wide range of language tasks has attracted considerable attention. However, the ever-increasing size of these models makes deployment and inference increasingly challenging. Structured pruning, an effective model compression technique, is gaining traction because it improves inference efficiency. Nevertheless, most previous optimization-based structured pruning methods sacrifice the uniform structure across layers in exchange for the flexibility needed to maintain performance. The resulting heterogeneous structure hinders the use of off-the-shelf inference acceleration techniques and makes continued training harder to configure efficiently. To address this issue, we propose a novel mask learning paradigm based on minimax optimization that obtains a uniform pruned structure by optimizing the masks under sparsity regularization. Extensive experiments demonstrate that our method maintains high performance while ensuring a uniform pruned model structure, and that it outperforms existing state-of-the-art (SOTA) methods.
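The abstract's central requirement is that every layer of the pruned model keeps the same width. The toy below, using made-up importance scores and hypothetical helper names (`global_threshold`, `uniform_top_s`), contrasts a single global threshold, which can leave layers with different widths, against keeping the same number `s` of units in every layer. It illustrates only the structural constraint, not how MaskPrune learns its masks.

```python
# Hypothetical illustration of heterogeneous vs. uniform pruned widths.
# scores_per_layer[l][i] is a made-up importance score of unit i (e.g. a head) in layer l.


def global_threshold(scores_per_layer, threshold):
    """Keep every unit whose score exceeds one global threshold (widths may differ)."""
    return [[i for i, s in enumerate(layer) if s > threshold] for layer in scores_per_layer]


def uniform_top_s(scores_per_layer, s):
    """Keep the s highest-scoring units in every layer (all widths equal s)."""
    kept = []
    for layer in scores_per_layer:
        order = sorted(range(len(layer)), key=lambda i: layer[i], reverse=True)
        kept.append(sorted(order[:s]))  # indices of retained units, in original order
    return kept


if __name__ == "__main__":
    scores = [
        [0.9, 0.8, 0.7, 0.1],  # layer 0
        [0.6, 0.2, 0.1, 0.05],  # layer 1
        [0.95, 0.9, 0.85, 0.8],  # layer 2
    ]
    print([len(k) for k in global_threshold(scores, 0.5)])  # [3, 1, 4]
    print([len(k) for k in uniform_top_s(scores, 2)])  # [2, 2, 2]
```

With equal widths, every pruned layer has the same tensor shapes, which is what lets standard inference kernels and training configurations be reused unchanged.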

Paper Structure

This paper contains 25 sections, 20 equations, 10 figures, 4 tables.

Figures (10)

  • Figure 1: Compresso and NutePrune result in heterogeneous inter-layer structures, whereas MaskPrune achieves uniform inter-layer structures, which simplify inference deployment and continual training.
  • Figure 2: Overall framework of MaskPrune. We optimize mask values through proximal gradient updates to identify the optimal pruning structure while simultaneously fine-tuning other parameters.
  • Figure 3: The number of attention heads and FFN intermediate dimensions retained after pruning Llama-7B to 50% sparsity.
  • Figure 4: Zero-shot performance and actual sparsity of the Llama-7B model at 50% sparsity under different decay rates.
  • Figure 5: Training loss under different optimization intervals.
  • ...and 5 more figures
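Figure 2's caption describes proximal gradient updates on the mask values, and the TL;DR pairs them with a straight-through estimator (STE) for gradient flow. The sketch below shows both mechanics on plain Python lists under stated assumptions: the sparsity term is taken to be an L1 penalty (whose proximal operator is soft-thresholding), a unit counts as kept when its mask value is positive, and all function names are hypothetical. It is not MaskPrune's implementation.

```python
# Hypothetical sketch: proximal gradient step on soft mask values under an assumed
# L1 penalty, plus a straight-through estimator (STE) for the hard 0/1 mask.


def soft_threshold(v, t):
    """Proximal operator of t * |v|: shrink v toward zero by t, clipping at zero."""
    if v > t:
        return v - t
    if v < -t:
        return v + t
    return 0.0


def proximal_step(mask, grad, lr, l1_weight):
    """Gradient step on the smooth loss, then the prox step for lr * l1_weight * |m|."""
    return [soft_threshold(m - lr * g, lr * l1_weight) for m, g in zip(mask, grad)]


def ste_forward(mask):
    """Forward pass: binarize the soft mask (assumption: positive means kept)."""
    return [1.0 if m > 0.0 else 0.0 for m in mask]


def ste_backward(grad_hard):
    """Backward pass: treat binarization as identity, passing the gradient through."""
    return list(grad_hard)


if __name__ == "__main__":
    mask = [0.8, 0.05, 0.4]  # soft mask values for three units
    grad_hard = [0.1, 0.1, -0.2]  # made-up dLoss/d(hard mask) from the forward pass
    grad_soft = ste_backward(grad_hard)  # STE: reuse it as dLoss/d(soft mask)
    mask = proximal_step(mask, grad_soft, lr=0.5, l1_weight=0.2)
    print(mask, ste_forward(mask))  # unit 1 is driven exactly to zero, i.e. pruned
```

The proximal step is what produces exact zeros: a plain subgradient step on $|m|$ would make small values oscillate around zero instead of settling on it, while the STE keeps a gradient signal flowing to masks whose forward value is a hard 0 or 1.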