PermLLM: Learnable Channel Permutation for N:M Sparse Large Language Models

Lancheng Zou, Shuo Yin, Zehua Pei, Tsung-Yi Ho, Farzan Farnia, Bei Yu

TL;DR

This work tackles pruning-induced degradation in N:M sparse large language models by introducing Learnable Channel Permutation (LCP) within PermLLM. LCP uses Sinkhorn normalization to relax discrete permutation matrices into differentiable soft permutations and adopts a block-wise design to keep the number of parameters and the computation manageable, enabling end-to-end optimization that minimizes the output discrepancy between the dense and pruned models. The framework integrates with existing one-shot pruning methods (e.g., Wanda, RIA), uses a cosine similarity loss to align dense and sparse outputs, and improves accuracy across LLaMA, Qwen, and OPT models; a custom CUDA kernel accelerates the channel-permutation operation at runtime. Key contributions include a practical, pruning-aware permutation learning mechanism, substantial parameter efficiency through block-wise permutation, and demonstrated superiority over baseline N:M pruning approaches on multiple model families. By reducing pruning errors and accelerating channel permutation, the approach makes high-sparsity LLMs more viable for real-world deployment.
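The Sinkhorn relaxation and the cosine-similarity objective mentioned above can be illustrated with a short NumPy sketch. It is not the authors' implementation: the temperature `tau`, the iteration count, the log-space formulation, and the exact loss form (one minus the mean per-token cosine similarity) are assumptions chosen for illustration, and gradient-based training of the logits is omitted.

```python
import numpy as np


def _logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable log(sum(exp(x))) along one axis, keeping dims for broadcasting.
    m = np.max(x, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))


def sinkhorn(logits: np.ndarray, tau: float = 1.0, n_iters: int = 100) -> np.ndarray:
    """Map a square score matrix to an approximately doubly stochastic
    'soft permutation' by alternating row and column normalization."""
    log_p = logits / tau  # assumed temperature: lower tau -> closer to a hard permutation
    for _ in range(n_iters):
        log_p = log_p - _logsumexp(log_p, axis=1)  # rows sum to 1
        log_p = log_p - _logsumexp(log_p, axis=0)  # columns sum to 1
    return np.exp(log_p)


def cosine_loss(y_dense: np.ndarray, y_sparse: np.ndarray, eps: float = 1e-8) -> float:
    """Assumed loss form: 1 - mean cosine similarity between matching rows (tokens)."""
    dot = np.sum(y_dense * y_sparse, axis=-1)
    norms = np.linalg.norm(y_dense, axis=-1) * np.linalg.norm(y_sparse, axis=-1)
    return float(np.mean(1.0 - dot / (norms + eps)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    P = sinkhorn(rng.normal(size=(8, 8)))
    print("row sums:", P.sum(axis=1).round(4))
    print("col sums:", P.sum(axis=0).round(4))
    y = rng.normal(size=(4, 16))
    print("loss(y, y) =", round(cosine_loss(y, y), 6))
```

Lowering `tau` pushes the soft matrix toward a hard permutation. A real training loop would differentiate through these iterations in an autodiff framework and, at the end, round the learned matrix to a hard permutation (the Hungarian algorithm is one common way to do this rounding).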

Abstract

Channel permutation is a powerful technique for enhancing the accuracy of N:M sparse models by reordering the channels of weight matrices to prioritize the retention of important weights. However, traditional channel permutation methods rely on handcrafted quality metrics, which often fail to accurately capture the true impact of pruning on model performance. To address this limitation, we propose PermLLM, a novel post-training pruning framework that introduces learnable channel permutation (LCP) for N:M sparsity. LCP leverages Sinkhorn normalization to transform discrete permutation matrices into differentiable soft permutation matrices, enabling end-to-end optimization. Additionally, PermLLM incorporates an efficient block-wise channel permutation strategy, which significantly reduces the number of learnable parameters and computational complexity. PermLLM seamlessly integrates with existing one-shot pruning methods to adaptively optimize channel permutations, effectively mitigating pruning-induced errors. Extensive experiments on the LLaMA series, Qwen, and OPT models demonstrate that PermLLM achieves superior performance in optimizing N:M sparse models. The code is available at https://github.com/lanchengzou/PermLLM.
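To make the block-wise strategy concrete, the sketch below restricts the permutation to independent reorderings inside contiguous blocks of input channels (equivalently, a block-diagonal permutation matrix) and then applies a 2:4 magnitude mask along the input dimension. The helper names, the block size, and the use of magnitude as the importance score are illustrative assumptions; PermLLM learns the per-block permutations, whereas here they are simply given.

```python
from typing import Optional, Sequence

import numpy as np


def nm_mask(w: np.ndarray, n: int = 2, m: int = 4) -> np.ndarray:
    """Boolean mask keeping the n largest-|w| entries in every group of m
    consecutive input channels of each output row (N:M sparsity)."""
    rows, cols = w.shape
    assert cols % m == 0, "input dimension must be divisible by m"
    groups = np.abs(w).reshape(rows, cols // m, m)
    drop = np.argsort(groups, axis=-1)[..., : m - n]  # the m - n smallest per group
    mask = np.ones(groups.shape, dtype=bool)
    np.put_along_axis(mask, drop, False, axis=-1)
    return mask.reshape(rows, cols)


def block_permute_columns(w: np.ndarray, block_perms: Sequence[np.ndarray]) -> np.ndarray:
    """Reorder input channels independently inside each contiguous block,
    which equals right-multiplying w by a block-diagonal permutation matrix."""
    b = len(block_perms[0])
    assert w.shape[1] == b * len(block_perms), "blocks must tile the input dimension"
    out = np.empty_like(w)
    for k, perm in enumerate(block_perms):
        cols = slice(k * b, (k + 1) * b)
        out[:, cols] = w[:, cols][:, perm]
    return out


def lcp_param_count(c_in: int, block: Optional[int] = None) -> int:
    """Learnable entries in a full (block=None) or block-wise soft permutation."""
    if block is None:
        return c_in * c_in                  # one C_in x C_in matrix
    return (c_in // block) * block * block  # C_in / B blocks of size B x B


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w = rng.normal(size=(4, 16))
    perms = [rng.permutation(8) for _ in range(2)]  # stand-ins for learned permutations
    w_perm = block_permute_columns(w, perms)
    w_sparse = w_perm * nm_mask(w_perm)
    print("nonzeros per group of 4:\n", (w_sparse != 0).reshape(4, 4, 4).sum(-1))
    print(lcp_param_count(4096), lcp_param_count(4096, block=64))
```

For a 4096-channel layer with 64-channel blocks, the count drops from 16,777,216 to 262,144 learnable entries, a 64-fold reduction. Smaller blocks save more parameters but restrict which channels can be exchanged.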

Paper Structure

This paper contains 19 sections, 10 equations, 3 figures, and 8 tables.

Figures (3)

  • Figure 1: Effects of different channel permutation strategies on the outputs. Channel order is shown in purple. This example uses magnitude pruning [nips15_pruning] for 2:4 sparsity. Score $S$ denotes the sum of retained weight importance, which serves as the quality metric for channel permutation [nips21_cp, iclr24_ria]. Loss is the mean squared error between the original output $\mathbf{y}$ and the output produced by the pruned weights. The output loss of direct 2:4 sparsity (i.e., without channel permutation) is 12.375. The results demonstrate that a channel permutation that maximizes the score may still degrade performance.
  • Figure 2: Illustration of learnable channel permutation with different granularity: (a) full matrix LCP; (b) block-wise LCP.
  • Figure 3: Visualization of masks obtained by different pruning methods. Blue denotes pruned weights and white denotes retained weights.
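The setup in Figure 1 can be reproduced in spirit with a few lines of NumPy: reorder the input channels of a weight matrix, apply 2:4 magnitude pruning, and report both the score $S$ (taken here as the sum of retained absolute weights) and the output MSE. The random matrices and function names are hypothetical, so the numbers will not match the figure; the sketch only shows how the two metrics are computed and why they can disagree. Because the input is permuted the same way as the weight columns, the dense output is unchanged and only the pruning mask differs between permutations.

```python
import numpy as np


def magnitude_nm_prune(w: np.ndarray, n: int = 2, m: int = 4) -> np.ndarray:
    """Zero the m - n smallest-|w| entries in each group of m consecutive columns."""
    rows, cols = w.shape
    groups = w.reshape(rows, cols // m, m).copy()
    drop = np.argsort(np.abs(groups), axis=-1)[..., : m - n]
    np.put_along_axis(groups, drop, 0.0, axis=-1)
    return groups.reshape(rows, cols)


def evaluate_permutation(w, x, perm, n=2, m=4):
    """Return (score S, output MSE) after reordering input channels by perm and pruning."""
    w_p = w[:, perm]  # permute weight columns (input channels)
    x_p = x[perm]     # permute inputs identically, so w_p @ x_p == w @ x
    w_pruned = magnitude_nm_prune(w_p, n, m)
    score = float(np.abs(w_pruned).sum())  # S with |w| as the (assumed) importance metric
    mse = float(np.mean((w @ x - w_pruned @ x_p) ** 2))
    return score, mse


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w = rng.normal(size=(4, 8))
    x = rng.normal(size=(8, 16))  # 8 input channels, 16 tokens
    for name, perm in [("identity", np.arange(8)), ("random", rng.permutation(8))]:
        s, loss = evaluate_permutation(w, x, perm)
        print(f"{name:8s}  S = {s:.3f}  MSE = {loss:.3f}")
```

A permutation that only reorders channels within each group of four leaves both metrics unchanged, since the mask depends only on which channels share a group; useful permutations move channels across groups.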