PermLLM: Learnable Channel Permutation for N:M Sparse Large Language Models
Lancheng Zou, Shuo Yin, Zehua Pei, Tsung-Yi Ho, Farzan Farnia, Bei Yu
TL;DR
This work tackles the accuracy degradation caused by pruning in N:M sparse large language models (LLMs) by introducing Learnable Channel Permutation (LCP) within the PermLLM framework. LCP uses Sinkhorn normalization to relax discrete permutation matrices into differentiable soft permutation matrices. It also adopts a block-wise design that keeps the number of learnable parameters and the computational cost manageable. Together, these enable end-to-end optimization that minimizes the output discrepancy between the dense and pruned models. The framework integrates with existing one-shot pruning methods (e.g., Wanda and RIA) and uses a cosine similarity loss to align dense and sparse outputs. It improves accuracy on LLaMA, Qwen, and OPT models, and a custom CUDA kernel speeds up the channel-permutation operation at runtime. Key contributions include a practical, pruning-aware permutation learning mechanism, substantial parameter savings through block-wise permutation, and improvements over baseline N:M pruning approaches across multiple model families. By reducing pruning errors and accelerating channel permutation, the approach makes high-sparsity LLMs more practical to deploy in real-world settings.
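A minimal pure-Python sketch of the Sinkhorn relaxation summarized above, assuming a square matrix of learnable scores per block, a temperature `tau`, and a fixed iteration count. The function names, the temperature, and the hardening step are illustrative assumptions, not PermLLM's implementation. In particular, the exhaustive search in `harden` stands in for an assignment solver such as the Hungarian algorithm and is only practical for tiny blocks. The last function shows the arithmetic behind block-wise permutation: splitting C channels into blocks of size B needs (C/B)·B² = C·B scores instead of C². The block size of 64 is an arbitrary illustrative choice.

```python
import itertools
import math


def sinkhorn(scores, tau=1.0, n_iters=100):
    """Relax a square score matrix into an approximately doubly stochastic
    'soft permutation' by alternately normalizing the rows and columns of
    exp(scores / tau). A smaller tau pushes the result toward a hard permutation."""
    n = len(scores)
    p = [[math.exp(s / tau) for s in row] for row in scores]
    for _ in range(n_iters):
        # Row normalization: every row sums to 1.
        row_sums = [sum(row) for row in p]
        p = [[v / s for v in row] for row, s in zip(p, row_sums)]
        # Column normalization: every column sums to 1.
        col_sums = [sum(p[i][j] for i in range(n)) for j in range(n)]
        p = [[p[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return p


def harden(soft):
    """Return the hard permutation that keeps the most soft-permutation mass.
    Exhaustive search is a stand-in for an assignment solver (tiny blocks only)."""
    n = len(soft)
    return max(itertools.permutations(range(n)),
               key=lambda perm: sum(soft[i][perm[i]] for i in range(n)))


def permutation_params(channels, block_size=None):
    """Learnable scores for a full (C x C) or a block-diagonal permutation."""
    if block_size is None:
        return channels * channels
    assert channels % block_size == 0, "channels must split evenly into blocks"
    return (channels // block_size) * block_size * block_size


scores = [[2.0, 0.0, 0.0],
          [0.0, 0.0, 3.0],
          [0.0, 1.0, 0.0]]
soft = sinkhorn(scores, tau=1.0)
print([[round(v, 3) for v in row] for row in soft])
print(harden(soft))  # (0, 2, 1): row i is matched to column perm[i]
print(permutation_params(4096), permutation_params(4096, block_size=64))  # 16777216 262144
```

Because the soft matrix is differentiable in `scores`, a training loop can backpropagate through it. The hard permutation is only needed when the pruned weights are finally fixed.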
Abstract
Channel permutation is a powerful technique for enhancing the accuracy of N:M sparse models by reordering the channels of weight matrices to prioritize the retention of important weights. However, traditional channel permutation methods rely on handcrafted quality metrics, which often fail to accurately capture the true impact of pruning on model performance. To address this limitation, we propose PermLLM, a novel post-training pruning framework that introduces learnable channel permutation (LCP) for N:M sparsity. LCP leverages Sinkhorn normalization to transform discrete permutation matrices into differentiable soft permutation matrices, enabling end-to-end optimization. Additionally, PermLLM incorporates an efficient block-wise channel permutation strategy, which significantly reduces the number of learnable parameters and computational complexity. PermLLM seamlessly integrates with existing one-shot pruning methods to adaptively optimize channel permutations, effectively mitigating pruning-induced errors. Extensive experiments on the LLaMA series, Qwen, and OPT models demonstrate that PermLLM achieves superior performance in optimizing N:M sparse models. The code is available at https://github.com/lanchengzou/PermLLM.
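To make the abstract's central point concrete, the sketch below applies 2:4 pruning after reordering the input channels of a tiny weight matrix. It scores each ordering with the cosine similarity loss between dense and sparse outputs that the TL;DR mentions. Permuting the columns of W together with the entries of x leaves the dense product unchanged, but it changes which weights share an M-group and therefore which weights survive pruning. Everything here is an illustrative assumption: plain weight magnitude stands in for the importance scores used by Wanda or RIA, the weights and calibration inputs are arbitrary toy values, and exhaustive search replaces PermLLM's learned Sinkhorn permutation. N:M pruning depends only on which channels share a group, so the search enumerates group memberships instead of all orderings. This is feasible only for 2M channels.

```python
import itertools
import math


def prune_n_m(w, n, m):
    """Zero the (m - n) smallest-magnitude weights in every group of m
    consecutive input channels of each row (magnitude stands in for Wanda/RIA scores)."""
    pruned = []
    for row in w:
        new_row = list(row)
        for start in range(0, len(row), m):
            group = range(start, start + m)
            for j in sorted(group, key=lambda k: abs(row[k]))[: m - n]:
                new_row[j] = 0.0
        pruned.append(new_row)
    return pruned


def matvec(w, x):
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def cosine_loss(a, b):
    """1 - cos(a, b): zero when the sparse output points the same way as the dense one."""
    dot = sum(p * q for p, q in zip(a, b))
    norm = math.sqrt(sum(p * p for p in a)) * math.sqrt(sum(q * q for q in b))
    if norm == 0.0:
        return 1.0  # convention: treat an all-zero output as unaligned
    return 1.0 - dot / norm


def permute(w, x, perm):
    """Reorder input channels: new position k holds original channel perm[k]."""
    return [[row[j] for j in perm] for row in w], [x[j] for j in perm]


def calib_loss(w, xs, perm, n, m):
    """Average dense-vs-sparse cosine loss over the calibration inputs."""
    total = 0.0
    for x in xs:
        w_p, x_p = permute(w, x, perm)
        total += cosine_loss(matvec(w, x), matvec(prune_n_m(w_p, n, m), x_p))
    return total / len(xs)


def best_grouping(w, xs, n=2, m=4):
    """Exhaustively choose which m channels form the first group (2m channels only)."""
    c = len(w[0])
    candidates = []
    for first in itertools.combinations(range(c), m):
        perm = list(first) + [j for j in range(c) if j not in first]
        candidates.append((calib_loss(w, xs, perm, n, m), perm))
    return min(candidates)


W = [[0.9, 0.8, 0.7, 0.6, 0.1, 0.05, 0.02, 0.01],
     [0.5, -0.4, 0.3, 0.2, 0.1, -0.1, 0.05, 0.0],
     [0.2, 0.1, -0.9, 0.8, 0.1, 0.3, 0.0, 0.05]]
XS = [[1.0, 0.5, -0.5, 1.0, 2.0, 1.0, 0.5, -1.0],
      [0.3, 1.0, 1.0, -0.2, 0.5, 0.5, 1.0, 1.0]]

identity = list(range(8))
print("identity loss:", calib_loss(W, XS, identity, 2, 4))
loss, perm = best_grouping(W, XS)
print("best loss:", loss, "permutation:", perm)
```

The identity ordering is one of the candidates, so the best grouping never does worse than the unpermuted baseline on the calibration inputs. PermLLM instead learns the permutation end to end through the Sinkhorn relaxation rather than by enumeration.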
