A Simple Linear Patch Revives Layer-Pruned Large Language Models
Xinrui Chen, Haoli Bai, Tao Yuan, Ruikang Liu, Kang Zhao, Xianzhi Yu, Lu Hou, Tian Guan, Yonghong He, Chun Yuan
TL;DR
This paper tackles the degradation observed when pruning Transformer layers in large language models by identifying a mismatch in activation magnitudes across the pruning interface. It introduces LinearPatch, a plug-in that fuses a Hadamard rotation with channel-wise scaling into a real symmetric patch matrix $P=\boldsymbol{H}\boldsymbol{D}\boldsymbol{H}^{\top}$, so the patch requires only a single GEMM insertion and adds minimal inference overhead. Complementing this, memory-efficient offline distillation with Top-$K$ logits guides a lightweight fine-tuning of the patch, recovering performance with only 5K samples and around 30 minutes on a single GPU. Across multiple models and benchmarks (QA, PPL, MMLU), LinearPatch consistently outperforms training-free baselines and yields substantial gains in post-training recovery, retaining up to 95.16% of the original performance on LLaMA-3-8B with distillation. The method is simple, hardware-friendly, and broadly applicable to diverse pruning configurations, offering a practical route to deploying leaner LLMs without major accuracy trade-offs.
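To make the fused operation concrete, the following is a minimal pure-Python sketch of building $P=\boldsymbol{H}\boldsymbol{D}\boldsymbol{H}^{\top}$ from calibration activations. The function names (`hadamard`, `matmul`, `linear_patch`) are hypothetical, and the rule used for $\boldsymbol{D}$ (the per-channel ratio of mean absolute activation after versus before the pruned layers, measured in the Hadamard-rotated basis) is an illustrative assumption, not the released implementation. Activations are treated as row vectors, so the patch is applied as $xP$, and the hidden size is assumed to be a power of two.

```python
import math
import random


def hadamard(n):
    """Orthonormal Sylvester Hadamard matrix of size n (n must be a power of two)."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = [[1.0]]
    while len(h) < n:
        # Sylvester step: H_2m = [[H, H], [H, -H]]
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    scale = 1.0 / math.sqrt(n)  # normalize so that H H^T = I
    return [[v * scale for v in row] for row in h]


def matmul(a, b):
    """Plain list-of-lists matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def linear_patch(x_pre, x_post, eps=1e-8):
    """Build the patch P = H D H^T from (tokens x channels) activations.

    Assumption (illustrative only): D holds, per rotated channel, the ratio of
    mean |activation| after the pruned block to mean |activation| before it.
    """
    n = len(x_pre[0])
    h = hadamard(n)
    r_pre = matmul(x_pre, h)    # rotate pre-interface tokens: x H
    r_post = matmul(x_post, h)  # rotate post-interface tokens: y H
    d = []
    for j in range(n):
        num = sum(abs(row[j]) for row in r_post) / len(r_post)
        den = sum(abs(row[j]) for row in r_pre) / len(r_pre)
        d.append(num / (den + eps))  # channel-wise scale in rotated space
    hd = [[h[i][j] * d[j] for j in range(n)] for i in range(n)]  # H D
    return matmul(hd, transpose(h))  # P = H D H^T, symmetric since D is diagonal


if __name__ == "__main__":
    random.seed(0)
    x_pre = [[random.gauss(0.0, 1.0) for _ in range(8)] for _ in range(16)]
    x_post = [[2.0 * v for v in row] for row in x_pre]  # toy: uniform 2x rescaling
    p = linear_patch(x_pre, x_post)
    patched = matmul(x_pre, p)  # one matrix multiplication at the interface
    err = max(abs(a - b) for ra, rb in zip(patched, x_post) for a, b in zip(ra, rb))
    print(f"max alignment error: {err:.2e}")
```

Because $\boldsymbol{H}$ is orthonormal and $\boldsymbol{D}$ is diagonal, $P$ is symmetric and costs one hidden-size by hidden-size multiplication per token. In the toy case above, where the post-interface activations are a uniform rescaling of the pre-interface ones, every entry of $\boldsymbol{D}$ equals that scale and $P$ reduces to a scaled identity.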
Abstract
Layer pruning has emerged as a widely used technique for compressing large language models (LLMs). However, existing layer pruning approaches often incur substantial performance degradation. We attribute the majority of this degradation to a single yet previously overlooked issue: \textit{the mismatch of activation magnitudes at the pruning interface}. The pre-interface activations exhibit significantly different scales from the post-interface ones, causing a distributional shift that propagates through the remaining layers. To address this issue, we introduce \textsc{LinearPatch}, a lightweight and plug-and-play technique that fuses two operations into one matrix multiplication at the pruning interface: (i) a Hadamard transformation that suppresses massive outliers at particular tokens and (ii) a channel-wise scaling that aligns activation statistics. On LLaMA-3-8B, \textsc{LinearPatch} preserves up to \textbf{94.15\%} of the original model's performance when pruning 5 out of 32 layers, outperforming the previous state of the art by \textbf{4\%}. The patch can be further refined with 5K unlabeled samples via memory-efficient offline distillation, pushing the retention to 95.16\% within only 30 minutes on a single GPU. Code is available at https://github.com/chenxinrui-tsinghua/LinearPatch.
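The abstract does not spell out the distillation objective, so the following is only a sketch of one common way to make offline distillation memory-efficient: the teacher's logits are computed once, and only the top-$K$ (index, logit) pairs per token are stored; the patched student is then trained against them. The function names (`topk_teacher`, `topk_kd_loss`) are hypothetical, and the loss form (a KL-style divergence from the teacher, renormalized over its top-$K$ entries, to the student's full-vocabulary softmax) is an assumption rather than the paper's exact formulation.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def topk_teacher(logits, k):
    """Offline step: keep only the teacher's top-K (indices, logits) for one token."""
    idx = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    return idx, [logits[i] for i in idx]


def topk_kd_loss(student_logits, teacher_idx, teacher_vals, temperature=1.0):
    """Assumed loss: sum_i p_T(i) * (log p_T(i) - log q_S(i)) over the stored top-K.

    p_T is the teacher softmax renormalized over its top-K entries; q_S is the
    student softmax over the full vocabulary, so student mass placed outside
    the teacher's top-K is also penalized. The value is always >= 0.
    """
    t_logp = log_softmax([v / temperature for v in teacher_vals])
    s_logp = log_softmax([v / temperature for v in student_logits])
    return sum(math.exp(lp) * (lp - s_logp[i]) for lp, i in zip(t_logp, teacher_idx))


if __name__ == "__main__":
    teacher = [2.0, 1.0, 0.5, -1.0, 0.0]  # toy 5-token vocabulary
    idx, vals = topk_teacher(teacher, k=2)
    print(idx, vals)
    print(topk_kd_loss([0.0] * 5, idx, vals))  # uniform student
    print(topk_kd_loss(teacher, idx, vals))    # student equal to teacher
```

Storing $K$ logits plus $K$ indices per token, instead of a full-vocabulary distribution, is what keeps the cached teacher outputs small enough for offline use; the trade-off is that the teacher's tail probabilities are discarded.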
