
MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models

Gongfan Fang, Hongxu Yin, Saurav Muralidharan, Greg Heinrich, Jeff Pool, Jan Kautz, Pavlo Molchanov, Xinchao Wang

TL;DR

MaskLLM tackles redundancy in large language models by learning semi-structured $N:M$ sparsity patterns through differentiable mask sampling with Gumbel Softmax. It models a learnable distribution over the candidate $N:M$ masks of each parameter block and optimizes the language modeling loss together with a sparse-weight regularization term. This enables end-to-end pruning on large datasets and transfer across domains through a Mask Prior. Empirically, MaskLLM achieves substantially lower perplexity than one-shot baselines (for example, $PPL=6.72$ vs. $10.42$ on Wikitext for LLaMA-2 7B) and up to about $1.4\times$ inference speedup on GPUs, and its masks transfer successfully to downstream tasks and vision transformers. These results demonstrate practical benefits for scalable LLM deployment, including lossless compression in certain domains, and robust performance across model families from LLaMA-2 to Nemotron-4.
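To make the sampling step concrete, here is a minimal pure-Python sketch for a single 2:4 block. It assumes plain Gumbel Softmax with a temperature `tau` over the six candidate masks and uniform initial logits. The function names (`candidate_masks`, `gumbel_softmax`, `sample_soft_mask`) are hypothetical and not taken from the MaskLLM code. A real implementation would hold the logits as trainable tensors in an autodiff framework so that the language modeling loss can backpropagate through the soft mask.

```python
import itertools
import math
import random


def candidate_masks(n=2, m=4):
    """Enumerate every binary mask of length m with exactly n ones (6 masks for 2:4)."""
    return [[1.0 if j in keep else 0.0 for j in range(m)]
            for keep in itertools.combinations(range(m), n)]


def gumbel_softmax(logits, tau=1.0, rng=random):
    """Relaxed one-hot sample: softmax((logits + Gumbel(0, 1) noise) / tau)."""
    scores = []
    for logit in logits:
        # Clamp U away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        scores.append((logit + gumbel) / tau)
    top = max(scores)  # subtract the max before exp for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sample_soft_mask(logits, candidates, tau=1.0, rng=random):
    """Soft mask = Gumbel-Softmax-weighted average of the candidate masks."""
    weights = gumbel_softmax(logits, tau, rng)
    m = len(candidates[0])
    return [sum(w * cand[j] for w, cand in zip(weights, candidates))
            for j in range(m)]


# Usage: one 4-weight block, uniform logits over the 6 candidate 2:4 masks.
cands = candidate_masks()
block = [0.9, -0.1, 0.4, -1.2]
soft = sample_soft_mask([0.0] * len(cands), cands, tau=0.5, rng=random.Random(0))
masked_block = [w * s for w, s in zip(block, soft)]
```

Each candidate keeps exactly two of the four weights, and the Gumbel Softmax weights sum to 1. As a result, every soft mask sums to 2, so the relaxation respects the 2:4 budget while remaining differentiable in the logits.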

Abstract

Large Language Models (LLMs) are distinguished by their massive parameter counts, which typically result in significant redundancy. This work introduces MaskLLM, a learnable pruning method that establishes semi-structured (or "N:M") sparsity in LLMs, aimed at reducing computational overhead during inference. Instead of developing a new importance criterion, MaskLLM explicitly models N:M patterns as a learnable distribution through Gumbel Softmax sampling. This approach facilitates end-to-end training on large-scale datasets and offers two notable advantages: 1) High-quality masks: our method effectively scales to large datasets and learns accurate masks; 2) Transferability: the probabilistic modeling of the mask distribution enables transfer learning of sparsity across domains or tasks. We assessed MaskLLM using 2:4 sparsity on various LLMs, including LLaMA-2, Nemotron-4, and GPT-3, with sizes ranging from 843M to 15B parameters, and our empirical results show substantial improvements over state-of-the-art methods. For instance, leading approaches achieve a perplexity (PPL) of 10 or greater on Wikitext, compared to the dense model's 5.12 PPL, whereas MaskLLM achieves a significantly lower 6.72 PPL solely by learning the masks with frozen weights. Furthermore, MaskLLM's learnable nature allows customized masks for lossless application of 2:4 sparsity to downstream tasks or domains. Code is available at https://github.com/NVlabs/MaskLLM.
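The transferability point can be sketched as follows. A mask from a one-shot criterion, or a mask learned on another domain, is turned into initial logits that favor candidates agreeing with it. Here plain weight magnitude stands in for one-shot methods such as SparseGPT. The overlap-based formula and the `strength` parameter are assumptions for illustration, not the paper's exact prior, and all function names are hypothetical.

```python
import itertools


def candidate_masks(n=2, m=4):
    """All binary length-m masks with exactly n ones (6 masks for 2:4)."""
    return [[1.0 if j in keep else 0.0 for j in range(m)]
            for keep in itertools.combinations(range(m), n)]


def magnitude_prior_mask(block, n=2):
    """One-shot stand-in criterion: keep the n largest-magnitude weights."""
    order = sorted(range(len(block)), key=lambda j: abs(block[j]), reverse=True)
    keep = set(order[:n])
    return [1.0 if j in keep else 0.0 for j in range(len(block))]


def prior_logits(prior, candidates, strength=3.0):
    """Hypothetical init: each logit grows with the candidate's overlap with the prior."""
    return [strength * sum(p * c for p, c in zip(prior, cand))
            for cand in candidates]


cands = candidate_masks()
block = [0.05, -0.8, 0.3, 0.01]
prior = magnitude_prior_mask(block)          # keeps indices 1 and 2
logits = prior_logits(prior, cands)
best = cands[logits.index(max(logits))]      # most likely candidate at initialization
```

Training would then start from `logits` instead of uniform values. Sampling therefore begins near the prior mask but can still move to other candidates as the loss dictates.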

Paper Structure

This paper contains 39 sections, 9 equations, 8 figures, 15 tables, and 1 algorithm.

Figures (8)

  • Figure 1: Learnable N:M sparsity for Large Language Models.
  • Figure 2: This work introduces learnable semi-structured sparsity for LLMs. MaskLLM models mask selection as a distribution learning problem, enabling the creation of accurate masks through end-to-end training on large-scale datasets. The learned general mask can be further transferred to downstream tasks or domains, achieving lossless compression.
  • Figure 3: Drawing a random mask from the learnable distribution with Gumbel Softmax. Each group of M consecutive parameters is associated with a learnable distribution over candidate masks. All illustrated computations, including Gumbel Softmax and the weighted averaging, are differentiable.
  • Figure 4: Consumed samples vs. PPL on LLaMA-2 7B. MaskLLM requires 128 samples for the prior and outperforms SparseGPT after 1280 samples.
  • Figure 5: (a) The L1 distance between masks sampled at adjacent training steps. (b) The maximum probability of the mask distribution, which serves as an indicator of convergence. In our method, the randomness of mask sampling is regulated by the scaling factor $\kappa$. An overly small $\kappa$ introduces high randomness, resulting in slow convergence, as shown in (b). An overly large $\kappa$ suppresses mask exploration and yields zero mask difference throughout training, as shown in (a).
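The κ trade-off described for Figure 5 can be reproduced in miniature. The sketch below freezes the logits of one 2:4 block and measures how much the hard sample changes between consecutive draws. The hard sample here is the Gumbel-max draw, which is the zero-temperature limit of Gumbel Softmax. The sketch assumes that κ multiplies the logits before the Gumbel noise is added. The logits, step count, and function names are hypothetical.

```python
import itertools
import math
import random


def candidate_masks(n=2, m=4):
    """All binary length-m masks with exactly n ones (6 masks for 2:4)."""
    return [[1.0 if j in keep else 0.0 for j in range(m)]
            for keep in itertools.combinations(range(m), n)]


def sample_hard_index(logits, kappa, rng):
    """Gumbel-max sample with logits scaled by kappa (assumed placement of kappa)."""
    scores = []
    for logit in logits:
        # Clamp U away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        scores.append(kappa * logit - math.log(-math.log(u)))
    return max(range(len(scores)), key=scores.__getitem__)


def mean_adjacent_l1(logits, candidates, kappa, steps=200, seed=0):
    """Average L1 distance between masks drawn at consecutive steps (frozen logits)."""
    rng = random.Random(seed)
    prev = candidates[sample_hard_index(logits, kappa, rng)]
    total = 0.0
    for _ in range(steps):
        cur = candidates[sample_hard_index(logits, kappa, rng)]
        total += sum(abs(a - b) for a, b in zip(prev, cur))
        prev = cur
    return total / steps


cands = candidate_masks()
logits = [0.1, 0.5, 0.2, 0.9, 0.3, 0.0]
noisy = mean_adjacent_l1(logits, cands, kappa=0.01)   # near-uniform sampling
frozen = mean_adjacent_l1(logits, cands, kappa=1e4)   # always the argmax mask
```

With κ = 1e4, the scaled logit gaps dwarf the bounded Gumbel noise, so the same mask is drawn at every step and `frozen` is exactly 0.0. This mirrors the zero mask difference in Figure 5(a). With κ = 0.01, the draws are close to uniform over the six candidates and `noisy` stays well above zero, which is the high-randomness regime the caption links to slow convergence.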