
MARLIN: Mixed-Precision Auto-Regressive Parallel Inference on Large Language Models

Elias Frantar, Roberto L. Castro, Jiale Chen, Torsten Hoefler, Dan Alistarh

TL;DR

MARLIN introduces a family of mixed-precision batched inference kernels for large language models that exploit 4-bit weight quantization on Ampere GPUs to dramatically reduce memory movement while maintaining compute efficiency. The core idea is to align kernel design with the hardware hierarchy (SM, Warp, Tensor Core) and to aggressively pipeline loads, dequantization, and affine scaling, achieving near-peak performance for many practical batch sizes. An extension, Sparse-MARLIN, adds 2:4 sparsity to the weight representation, delivering additional speedups on suitable Tensor Cores. End-to-end evaluation with vLLM demonstrates up to 2.8× speedups in generation time for common models, with Sparse-MARLIN offering up to about 3.3×, indicating strong practical impact for multi-user serving scenarios and realistic deployment. The work also provides GPTQ-friendly modifications and releases the code openly, enabling broader adoption and extension to other compression schemes.
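To make the "dequantization and affine scaling" step concrete, here is a minimal NumPy sketch of the arithmetic involved. It assumes eight unsigned 4-bit codes packed into each 32-bit word (lowest nibble first), symmetric quantization with an implicit zero point of 8, and one scale per group of `group_size` consecutive weights. The helper names `pack_int4`, `unpack_int4`, and `dequantize` are hypothetical, and the layout is an assumption for illustration: it makes no attempt to match MARLIN's actual on-GPU weight layout or its register-level conversion tricks.

```python
import numpy as np

def pack_int4(q):
    """Pack unsigned 4-bit codes (0..15), 8 per uint32, lowest nibble first.

    Assumed layout for illustration only; not MARLIN's real weight format.
    """
    q = np.asarray(q, dtype=np.uint32).reshape(-1, 8)
    shifts = np.arange(8, dtype=np.uint32) * 4
    # Shift each code into its nibble slot, then OR the 8 slots together.
    return np.bitwise_or.reduce(q << shifts, axis=1)

def unpack_int4(packed):
    """Inverse of pack_int4: return a flat array of 4-bit codes."""
    packed = np.asarray(packed, dtype=np.uint32)
    shifts = np.arange(8, dtype=np.uint32) * 4
    return ((packed[:, None] >> shifts) & 0xF).reshape(-1)

def dequantize(packed, scales, group_size):
    """Affine dequantization w = scale * (q - 8), one scale per group.

    The zero point of 8 (symmetric 4-bit) is an assumption of this sketch.
    """
    q = unpack_int4(packed).astype(np.float32)
    s = np.repeat(np.asarray(scales, dtype=np.float32), group_size)
    if s.shape != q.shape:
        raise ValueError("need exactly one scale per group of weights")
    return (s * (q - 8.0)).astype(np.float16)
```

For example, `dequantize(pack_int4(np.arange(16)), [0.5, 2.0], 8)` maps the first group of codes to -4.0 ... -0.5 and the second to 0.0 ... 14.0. Only a quarter of the FP16 weight bytes ever travel from memory; the expansion back to FP16 happens after the load, which is why the kernel must overlap it with other work.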

Abstract

As inference on Large Language Models (LLMs) emerges as an important workload in machine learning applications, weight quantization has become a standard technique for efficient GPU deployment. Quantization not only reduces model size, but has also been shown to yield substantial speedups for single-user inference, due to reduced memory movement, with low accuracy impact. Yet, it remains open whether speedups are also achievable in batched settings with multiple parallel clients, which are highly relevant for practical serving. It is unclear whether GPU kernels can be designed to remain practically memory-bound while supporting the substantially increased compute requirements of batched workloads. This paper resolves this question positively by describing the design of Mixed-precision Auto-Regressive LINear kernels, called MARLIN. Concretely, given a model whose weights are compressed via quantization to, e.g., 4 bits per element, MARLIN shows that batch sizes up to 16-32 can be supported with close to the maximum (4×) quantization speedup, and larger batch sizes up to 64-128 with gradually decreasing, but still significant, acceleration. MARLIN accomplishes this via a combination of techniques, such as asynchronous memory access, complex task scheduling and pipelining, and bespoke quantization support. Our experiments show that MARLIN's near-optimal performance on individual LLM layers across different scenarios can also lead to end-to-end LLM inference speedups (of up to 2.8×) when integrated with the popular vLLM serving engine. Finally, MARLIN is extensible to further compression techniques, like NVIDIA 2:4 sparsity, leading to additional speedups.
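The claim that small batches keep close to the full 4× speedup while larger batches see it shrink follows from a memory-bound versus compute-bound argument. The sketch below is an idealized roofline model of one linear layer, not the paper's method or measurements: it counts only weight traffic (ignoring activations and outputs), and the bandwidth and peak-FLOP figures are hypothetical placeholders. The names `layer_time` and `ideal_speedup` are illustrative.

```python
def layer_time(batch, k, n, bits, bw_bytes_per_s, peak_flops):
    """Idealized roofline time (seconds) for a (batch x k) @ (k x n) matmul.

    Only weight bytes are counted; activation/output traffic is ignored,
    which is a fair approximation only when batch is much smaller than k, n.
    """
    weight_bytes = k * n * bits / 8
    flops = 2 * batch * k * n  # one multiply + one add per weight per row
    # The layer is limited by whichever resource takes longer.
    return max(weight_bytes / bw_bytes_per_s, flops / peak_flops)

def ideal_speedup(batch, k=8192, n=8192, bw=600e9, peak=125e12):
    """FP16-weight time divided by 4-bit-weight time (hypothetical GPU specs)."""
    fp16 = layer_time(batch, k, n, 16, bw, peak)
    int4 = layer_time(batch, k, n, 4, bw, peak)
    return fp16 / int4

for b in (1, 16, 32, 64, 128, 256):
    print(b, round(ideal_speedup(b), 2))
```

With 4-bit weights the arithmetic intensity is about 4 × batch FLOP per byte, so under these placeholder specs the model keeps the full 4× until a batch of roughly peak / (4 × bw) ≈ 52, then decays toward 1× once the FP16 baseline also becomes compute-bound. Because a roofline ignores pipeline stalls, dequantization cost, and scheduling overhead, it is an upper bound; it illustrates why the speedup must eventually fall off, not where a real kernel's threshold lies.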

Paper Structure (38 sections, 2 equations, 16 figures, 2 tables, 1 algorithm)

Figures (16)

  • Figure 1: Illustration of MARLIN peak performance while increasing batch size, for a single large linear LLM layer, compared with other popular open-source kernels, showing that we can achieve near-optimal performance in this scenario.
  • Figure 2: Illustration of the asynchronous copy operation with and without L1 bypass (right) vs. standard operations (left).
  • Figure 3: Levels of pipelining in the MARLIN kernel.
  • Figure 4: Illustration of MARLIN's warp layout. Multiple warps accumulate partial results of the same output tile; see also Algorithm 1 for the corresponding pseudocode.
  • Figure 5: MARLIN's striped partitioning scheme.
  • ...and 11 more figures
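The warp layout in Figure 4, where multiple warps accumulate partial results of the same output tile, amounts to splitting the reduction (K) dimension of a tile's matrix product and summing the partial tiles afterwards. The NumPy sketch below simulates this on the CPU. The function name `tile_matmul_k_split`, the contiguous equal-sized K slices, and the single final sum are hypothetical choices for illustration, not MARLIN's actual scheduling (see Algorithm 1 in the paper for that).

```python
import numpy as np

def tile_matmul_k_split(a, b, num_warps):
    """Compute one output tile a @ b by splitting K across simulated warps.

    Each simulated warp multiplies its own contiguous K slice and produces a
    full-size partial tile; the partials are then reduced into the result.
    """
    m, k = a.shape
    k2, n = b.shape
    if k != k2 or k % num_warps != 0:
        raise ValueError("K must match and divide evenly among warps")
    chunk = k // num_warps
    partials = [
        a[:, w * chunk:(w + 1) * chunk] @ b[w * chunk:(w + 1) * chunk, :]
        for w in range(num_warps)
    ]
    # Final reduction of the partial tiles; a GPU kernel would typically
    # perform this step through shared memory.
    return np.sum(partials, axis=0)
```

Because addition is associative (exactly so for small integers, up to rounding for general floats), the result matches `a @ b`. The design point is that every warp stays busy on the same output tile even when the tile is too small in the M and N dimensions to give each warp its own slice, which is the situation in small-batch inference.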