
Towards Low-bit Communication for Tensor Parallel LLM Inference

Harry Dong, Tyler Johnson, Minsik Cho, Emad Soroush

TL;DR

This work introduces a quantization method for tensor-parallel server large language model (LLM) inference that reduces communicated values on average from 16 bits to 4.2 bits while preserving nearly all of the original performance.

Abstract

Tensor parallelism provides an effective way to increase server large language model (LLM) inference efficiency, despite adding a communication cost. However, as server LLMs continue to scale in size, they will need to be distributed across more devices, magnifying the communication cost. One way to approach this problem is quantization, but current methods for LLMs tend to avoid quantizing the features that tensor parallelism needs to communicate. Taking advantage of consistent outliers in communicated features, we introduce a quantization method that reduces communicated values on average from 16 bits to 4.2 bits while preserving nearly all of the original performance. For instance, our method maintains around 98.0% and 99.5% of Gemma 2 27B's and Llama 2 13B's original performance, respectively, averaged across all evaluated tasks.
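As a back-of-the-envelope illustration (not the paper's own accounting), suppose a fraction $f$ of communicated features stays in BF16 (16 bits) and the remainder is sent as Int4 (4 bits), ignoring any bits spent on quantization scales or zero points. The average bit-width $\bar{b}$ then follows directly; the resulting $f$ is hypothetical, since the source does not state how its 4.2-bit average breaks down.

```latex
\bar{b} = 16 f + 4 (1 - f) = 4 + 12 f,
\qquad
\bar{b} = 4.2 \;\Rightarrow\; f = \frac{0.2}{12} = \frac{1}{60} \approx 1.7\%.
```

Under these assumptions, a 4.2-bit average would correspond to keeping roughly 1 in 60 features in BF16; any scale or zero-point overhead counted in the average would push the actual high-precision fraction lower.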

Paper Structure

This paper contains 7 sections, 4 equations, 3 figures, 1 table.

Figures (3)

  • Figure 1: Sorted aggregated quantization ranges, $\bar{\bm{R}}_j$, of each attention (left) and feedforward (right) block in Gemma 2 27B, with the mean across all layers in red. Values are normalized so that the maximum range in each layer is 1.
  • Figure 2: Our hybrid quantization algorithm. A small set of features, selected by their aggregated quantization ranges, is kept in BF16 while all other features are quantized to Int4 before inter-device communication. All tensors are then converted back to BF16 and summed to synchronize across devices.
  • Figure 3: Gemma 2 27B performance when features are quantized to varying numbers of bits. Our method achieves the best accuracy for every quantization precision.
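The hybrid scheme described in the Figure 1 and Figure 2 captions can be sketched in plain Python as below. This is a minimal, hypothetical illustration, not the authors' implementation: the function names, the definition of a feature's range as max minus min over the tokens of a calibration batch (averaged over batches), the top-k selection rule, the per-token asymmetric min-max 4-bit scheme, and the use of Python floats as a stand-in for BF16 are all assumptions. The final summation simulates the all-reduce that synchronizes devices.

```python
"""Hypothetical sketch of hybrid BF16/Int4 communication for tensor parallelism.

Assumptions (not from the source): a feature's range is max - min over the
tokens of a calibration batch, averaged over batches; Int4 is asymmetric
min-max quantization per token vector; Python floats stand in for BF16.
"""
from typing import List, Sequence, Set, Tuple

Vector = List[float]
Batch = List[Vector]  # rows are tokens, columns are features


def aggregated_ranges(batches: Sequence[Batch]) -> Vector:
    """Average, over calibration batches, of each feature's max - min range."""
    num_features = len(batches[0][0])
    totals = [0.0] * num_features
    for batch in batches:
        for j in range(num_features):
            column = [row[j] for row in batch]
            totals[j] += max(column) - min(column)
    return [t / len(batches) for t in totals]


def select_high_precision(ranges: Vector, k: int) -> Set[int]:
    """Indices of the k features with the largest aggregated ranges."""
    order = sorted(range(len(ranges)), key=lambda j: ranges[j], reverse=True)
    return set(order[:k])


def quantize_int4(values: Vector) -> Tuple[List[int], float, float]:
    """Asymmetric min-max quantization to 16 levels (codes 0..15)."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 15 if hi > lo else 1.0  # avoid dividing by zero
    codes = [min(15, max(0, round((v - lo) / scale))) for v in values]
    return codes, lo, scale


def dequantize_int4(codes: List[int], lo: float, scale: float) -> Vector:
    """Map 4-bit codes back to floats."""
    return [lo + c * scale for c in codes]


def compress_and_restore(x: Vector, keep: Set[int]) -> Vector:
    """Simulate one device's send: kept features stay exact, others go to Int4."""
    others = [j for j in range(len(x)) if j not in keep]
    restored = list(x)  # kept features pass through unchanged (BF16 stand-in)
    if others:
        codes, lo, scale = quantize_int4([x[j] for j in others])
        for j, v in zip(others, dequantize_int4(codes, lo, scale)):
            restored[j] = v
    return restored


def hybrid_all_reduce(partials: Sequence[Vector], keep: Set[int]) -> Vector:
    """Sum every device's restored partial result (stand-in for all-reduce)."""
    restored = [compress_and_restore(p, keep) for p in partials]
    return [sum(col) for col in zip(*restored)]


# Usage: feature 1 carries the outliers, so it has the largest average range.
calib = [
    [[0.1, 8.0, -0.2, 0.3], [0.0, -7.5, 0.1, -0.1]],
    [[0.2, 9.0, 0.0, 0.1], [-0.1, -8.5, 0.2, 0.0]],
]
keep = select_high_precision(aggregated_ranges(calib), k=1)  # keep == {1}

# Two devices' partial outputs for one token, summed after hybrid compression.
partials = [[0.05, 12.0, -0.3, 0.2], [0.1, -3.0, 0.25, -0.05]]
result = hybrid_all_reduce(partials, keep)
```

Keeping the outlier feature out of the Int4 group matters because min-max quantization spends its 16 levels across the full range of the values it sees; a single large outlier would otherwise stretch the step size for every other feature in the vector.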