PAT: Pruning-Aware Tuning for Large Language Models

Yijiang Liu, Huanrui Yang, Youxin Chen, Rongyu Zhang, Miao Wang, Yuan Du, Li Du

TL;DR

Pruning-Aware Tuning (PAT) introduces a unified framework that simultaneously prunes and fine-tunes large language models. By inserting Hybrid Sparsification Modules between the Attention and FFN components, and controlling pruning with a single Unified Sparsification Mask plus an efficient Hybrid-Identity-Operator, PAT achieves substantial parameter reduction with minimal accuracy loss. Identity Loss further stabilizes training by decoupling rotation and scaling within the sparsification modules. Across multiple LLMs (notably Llama2-7B) and pruning ratios, PAT delivers speedups of up to $1.33\times$ and accuracy competitive with or superior to LoRA and other baselines, while preserving downstream task performance.
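To make the unified-mask idea concrete, here is a minimal NumPy sketch of a Hybrid Sparsification Module under stated assumptions. Figure 4's caption refers to a rank M inside the HIO, so we assume the operator is the identity plus a LoRA-style low-rank term, $H = I + AB$ with rank $r \ll d$; that exact form is our assumption, not the paper's definition. We also replace the trainable, gated mask with a fixed binary vector, and the class name `HybridSparsificationModule` is hypothetical. The point being illustrated is that every HSM holds a reference to the same mask, so all layers zero out the same hidden channels, which is what makes the pruning structural.

```python
import numpy as np


class HybridSparsificationModule:
    """Hypothetical HSM: assumed operator H = I + A @ B, followed by a shared mask."""

    def __init__(self, d, rank, shared_mask, rng):
        self.A = rng.standard_normal((d, rank)) * 0.01  # (d, r) trainable factor
        self.B = np.zeros((rank, d))                    # (r, d), zero-init so H starts as I
        self.mask = shared_mask                         # same array object in every HSM

    def operator(self):
        d = self.A.shape[0]
        return np.eye(d) + self.A @ self.B

    def __call__(self, x):
        # x: (batch, d). Transform the channels, then zero out the pruned ones.
        return (x @ self.operator().T) * self.mask

    def num_trainable(self):
        return self.A.size + self.B.size


d, rank, n_layers = 8, 2, 3
rng = np.random.default_rng(0)

# One unified mask for the whole model: channels 1 and 5 are pruned in every layer.
mask = np.ones(d)
mask[[1, 5]] = 0.0
hsms = [HybridSparsificationModule(d, rank, mask, rng) for _ in range(n_layers)]

x = rng.standard_normal((4, d))
outs = [hsm(x) for hsm in hsms]
print(hsms[0].num_trainable(), "trainable params vs", d * d, "for a dense operator")
# prints: 32 trainable params vs 64 for a dense operator
```

Because `B` starts at zero, each operator starts as the identity, so tuning begins from the unmodified model. In this sketch each operator has $2dr$ trainable parameters instead of $d^2$, which is the sense in which such an operator can keep training overhead comparable to LoRA.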

Abstract

Large language models (LLMs) excel at language tasks, especially when supervised fine-tuning follows pre-training. However, their substantial memory and computational requirements hinder practical deployment. Structural pruning, which removes less significant weight dimensions, is one solution. Yet traditional post-hoc pruning often causes significant performance loss, and further fine-tuning recovers only part of it because the pruned model has reduced capacity. Since fine-tuning refines the general and chaotic knowledge in pre-trained models, we integrate structural pruning into fine-tuning and propose the Pruning-Aware Tuning (PAT) paradigm, which eliminates model redundancy while preserving model performance to the greatest extent possible. Specifically, we insert Hybrid Sparsification Modules (HSMs) between the Attention and FFN components to sparsify the upstream and downstream linear modules accordingly. Each HSM comprises a lightweight operator and a globally shared trainable mask. The lightweight operator keeps training overhead comparable to that of LoRA, while the trainable mask unifies the channels to be sparsified, ensuring structural pruning. Additionally, we propose the Identity Loss, which decouples the transformation and scaling properties of the HSMs to enhance training robustness. Extensive experiments demonstrate that PAT excels in both performance and efficiency. For example, our Llama2-7B model with a 25% pruning ratio achieves a $1.33\times$ speedup while outperforming the LoRA-finetuned model by up to 1.26% in accuracy at a similar training cost. Code: https://github.com/kriskrisliu/PAT_Pruning-Aware-Tuning
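As a back-of-the-envelope check on the reported figures (an idealization, not the paper's own analysis): if inference cost scaled exactly with the fraction $p$ of removed weights, the expected speedup would be

```latex
\text{speedup} \approx \frac{1}{1-p} = \frac{1}{1-0.25} = \frac{4}{3} \approx 1.33
```

which lines up with the reported $1.33\times$. Real speedups also depend on unpruned components, memory bandwidth, and kernel efficiency, so this agreement should not be read as a general rule.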

Paper Structure (27 sections, 12 equations, 6 figures, 9 tables)

Figures (6)

  • Figure 1: Comparison of zero-shot accuracy averaged on downstream tasks. Various pruning methods at a 25% pruning ratio, as well as the unpruned LoRA, are employed. Our PAT (red) notably outperforms LLM-Pruner and SliceGPT, and is comparable to LoRA (blue), surpassing LoRA by 1.26% on the Llama2-7B model.
  • Figure 2: Framework of our Pruning-Aware Tuning (PAT), featuring Hybrid Sparsification Modules (HSMs) positioned between the Attention and Feed-Forward Network (FFN) components. Each HSM includes a Hybrid-Identity-Operator (HIO) and a globally shared trainable mask. During training, the mask values are updated until convergence. At inference, each pruned HSM is merged into its upstream linear layer, and the downstream layers, whose inputs now contain zero-valued channels, are pruned accordingly.
  • Figure 3: The differentiable gating function $\mathcal{G}(\cdot)$.
  • Figure 4: Training efficiency and accuracy comparison for Llama2-7B. PAT results are labeled "HIO-M, LoRA-N", where M and N denote the ranks of the HIO and the LoRA, respectively; the LoRA baselines are labeled "LoRA-N".
  • Figure 5: The VRAM usage and the evaluation accuracy of Llama2 models under various pruning ratios.
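Figure 2's caption describes how a trained HSM disappears at inference: it is merged into the upstream linear layer, and the downstream layer drops the input channels that the mask zeroed. The NumPy sketch below checks that this rewrite is exact on a toy chain (upstream linear, then HSM, then downstream linear). It assumes bias-free linear layers, a dense operator `H`, and an already-binarized mask `m`, and it ignores the residual connections and normalization layers of a real transformer block; `merge_and_prune` and the other names are hypothetical.

```python
import numpy as np

# Toy shapes: W_up (d x d_in) produces the d-dim hidden state, the HSM applies
# operator H (d x d) and binary mask m (d,), and W_down (d_out x d) consumes it.
rng = np.random.default_rng(0)
d_in, d, d_out = 6, 8, 5
W_up = rng.standard_normal((d, d_in))
H = np.eye(d) + 0.1 * rng.standard_normal((d, d))
m = np.array([1, 0, 1, 1, 0, 1, 1, 1], dtype=float)  # channels 1 and 4 pruned
W_down = rng.standard_normal((d_out, d))


def training_forward(x):
    h = x @ W_up.T         # upstream linear
    h = (h @ H.T) * m      # HSM: operator, then mask
    return h @ W_down.T    # downstream linear


def merge_and_prune(W_up, H, m, W_down):
    keep = np.flatnonzero(m)             # indices of surviving channels
    W_merged = (H @ W_up)[keep, :]       # fold H into upstream, drop masked rows
    W_merged = W_merged * m[keep, None]  # fold in mask values (all 1 for a binary mask)
    W_down_pruned = W_down[:, keep]      # drop columns that only ever see zeros
    return W_merged, W_down_pruned


W_merged, W_down_pruned = merge_and_prune(W_up, H, m, W_down)


def inference_forward(x):
    return (x @ W_merged.T) @ W_down_pruned.T
```

Folding `H` and the mask into one matrix means the operator adds no inference cost, and slicing rows of the upstream weight and columns of the downstream weight yields genuinely smaller dense matrices rather than sparse ones, which is why structural pruning can translate into wall-clock speedups.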