Flash Communication: Reducing Tensor Parallelization Bottleneck for Fast Large Language Model Inference

Qingyuan Li, Bo Zhang, Liang Ye, Yifan Zhang, Wei Wu, Yerui Sun, Lin Ma, Yuchen Xie

TL;DR

This paper tackles the tensor-parallel inference bottleneck in large language models by introducing Flash Communication, a low-bit activation quantization technique paired with a two-step All-Reduce. A fused CUDA kernel implements the approach, substantially reducing intra-node communication time and lowering time-to-first-token (TTFT) with minimal accuracy impact. In extensive experiments on LLaMA-2/3 models running on L40 and A100 GPUs, the method delivers up to 2x TTFT improvement while preserving accuracy across benchmarks. The work offers a practical path to faster, more scalable LLM inference by shrinking communication volume and cutting the number of reduction hops in tensor-parallel setups.
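The two-step All-Reduce with low-bit payloads can be illustrated with a minimal single-process sketch. It assumes symmetric per-group integer quantization (one FP32 scale per group), FP32 accumulation, and NumPy arrays standing in for per-GPU buffers; no NCCL or CUDA calls are made. The function names, bit widths, and group size are hypothetical illustration choices, not the paper's fused kernel. Each rank's partial result is quantized before the Reduce-Scatter, and each reduced chunk is quantized again before the All-Gather.

```python
import numpy as np

# Assumed scheme for illustration: symmetric round-to-nearest integer quantization
# with one FP32 scale per `group_size` contiguous values. Values are kept in int8
# containers; a real low-bit kernel would bit-pack INT4 values.
def quantize(x, bits=8, group_size=128):
    qmax = 2 ** (bits - 1) - 1
    groups = x.reshape(-1, group_size)
    scale = np.abs(groups).max(axis=1, keepdims=True) / qmax
    scale[scale == 0] = 1.0  # all-zero group: any nonzero scale reconstructs zeros
    q = np.clip(np.round(groups / scale), -qmax, qmax).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return (q.astype(np.float32) * scale).reshape(-1)

def quantized_two_step_allreduce(partials, bits=8, group_size=128):
    """Simulate a quantized Reduce-Scatter + All-Gather over len(partials) ranks.

    Each entry of `partials` is one rank's 1-D partial sum; its length must be
    divisible by len(partials) * group_size.
    """
    n = len(partials)
    chunks = [np.split(p.astype(np.float32), n) for p in partials]

    # Step 1 (Reduce-Scatter): rank r quantizes chunk j and sends it to rank j,
    # which dequantizes every received chunk and accumulates in FP32.
    # (A rank's own chunk is also quantized here, for simplicity.)
    reduced = []
    for j in range(n):
        acc = np.zeros_like(chunks[0][j])
        for r in range(n):
            acc += dequantize(*quantize(chunks[r][j], bits, group_size))
        reduced.append(acc)

    # Step 2 (All-Gather): rank j quantizes its reduced chunk once and broadcasts it;
    # every rank dequantizes all chunks and concatenates them into the full result.
    return np.concatenate([dequantize(*quantize(c, bits, group_size)) for c in reduced])

rng = np.random.default_rng(0)
partials = [rng.standard_normal(4096).astype(np.float32) for _ in range(4)]  # TP=4
exact = np.sum(partials, axis=0)
for bits in (8, 4):
    approx = quantized_two_step_allreduce(partials, bits=bits)
    print(bits, "bits, max abs error:", float(np.abs(approx - exact).max()))
```

Quantizing before the Reduce-Scatter shrinks what every rank sends in the first step, and re-quantizing the reduced chunk keeps the All-Gather payload low-bit as well; the price is two rounds of quantization error, which is why bit width and group size matter.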

Abstract

The ever-increasing sizes of large language models necessitate distributed solutions for fast inference that exploit multi-dimensional parallelism, where computational loads are split across various accelerators such as GPU clusters. However, this approach often introduces significant communication overhead, especially on devices with limited bandwidth. In this paper, we introduce Flash Communication, a novel low-bit compression technique designed to alleviate the tensor-parallelism communication bottleneck during inference. Our method substantially boosts intra-node communication speed by more than 3x and reduces the time-to-first-token by 2x, with nearly no sacrifice in model accuracy. Extensive experiments on various up-to-date LLMs demonstrate the effectiveness of our approach.
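To see where the communication savings come from, the bytes each GPU sends in one All-Reduce can be estimated with simple arithmetic. The sketch below assumes the Figure 1 workload (batch size 8, 1024 input tokens), LLaMA-3-70B's hidden size of 8192, TP=4, and a hypothetical INT4 layout with one FP16 scale per 128 values. It counts bytes only and ignores latency, kernel launch overhead, and quantization cost, so it bounds the volume reduction rather than predicting the measured speedup.

```python
def payload_bytes(num_elems, bits, group_size=None, scale_bits=16):
    """Bytes for one tensor: packed values plus one scale per group (assumed layout)."""
    total_bits = num_elems * bits
    if group_size is not None:
        total_bits += (num_elems // group_size) * scale_bits
    return total_bits / 8

def allreduce_bytes_per_gpu(payload, world_size):
    """Ring All-Reduce (or Reduce-Scatter + All-Gather) sends 2(N-1)/N of the payload per GPU."""
    return 2 * (world_size - 1) / world_size * payload

n = 8 * 1024 * 8192  # batch 8 x 1024 tokens x hidden size 8192
fp16 = allreduce_bytes_per_gpu(payload_bytes(n, 16), world_size=4)
int4 = allreduce_bytes_per_gpu(payload_bytes(n, 4, group_size=128), world_size=4)
print(f"FP16: {fp16 / 2**20:.1f} MiB, INT4: {int4 / 2**20:.1f} MiB, ratio {fp16 / int4:.2f}x")
# FP16: 192.0 MiB, INT4: 49.5 MiB, ratio 3.88x
```

The per-group scales are why the ratio falls slightly short of the naive 16/4 = 4x; smaller groups improve accuracy but add more scale bytes.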

Paper Structure

This paper contains 27 sections, 4 equations, 14 figures, 7 tables, 1 algorithm.

Figures (14)

  • Figure 1: Prefill cost breakdown of LLaMA-3-70B operations with and without Flash Communication, as measured by NVIDIA Nsight Systems (NSys) [NVIDIANsightSystems]. Tested on 4×L40/A100 GPUs (TP=4) with a batch size of 8, each sequence having 1024 input and 64 output tokens. NCCL's [nvidia2024nccl] Ring All-Reduce is applied. The x-tick labels (e.g., L40 FP16/FP16) denote GPU type, model weight precision, and communication precision, respectively.
  • Figure 2: Prefill cost breakdown of LLaMA-3-70B operations at various sequence lengths. Tested on 4×L40 GPUs (TP=4) with a batch size of 8.
  • Figure 3: Tensor parallelism for a LLaMA-3 transformer block. Our Flash All-Reduce is applied to speed up communication.
  • Figure 4: Activation quantization of LLaMA-3-8B on C4 with various block sizes. Starting from 4096 (the hidden dimension size), the granularity becomes progressively finer down to 128.
  • Figure 5: Left: comparison of All-Reduce quantization MSE for o_proj and d_proj. Right: quantization MSE when quantizing before Reduce-Scatter (RS) vs. before All-Gather (AG).
  • ...and 9 more figures
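The block-size trend in Figure 4 can be reproduced qualitatively with a small experiment. The sketch below uses synthetic Gaussian activations with a few injected large-magnitude channels (an assumption standing in for real LLaMA-3-8B activations on C4) and 4-bit symmetric round-to-nearest quantization with one scale per block. The helper `group_quant_mse` is hypothetical; the sketch illustrates why finer blocks lower error, not the paper's measured values.

```python
import numpy as np

def group_quant_mse(x, bits, block_size):
    """MSE of symmetric round-to-nearest quantization with one scale per block_size values."""
    qmax = 2 ** (bits - 1) - 1
    groups = x.reshape(-1, block_size)  # contiguous blocks along the hidden dimension
    scale = np.abs(groups).max(axis=1, keepdims=True) / qmax
    scale[scale == 0] = 1.0  # all-zero block: avoid division by zero
    deq = np.clip(np.round(groups / scale), -qmax, qmax) * scale
    return float(np.mean((deq - groups) ** 2))

rng = np.random.default_rng(0)
hidden = 4096
acts = rng.standard_normal((64, hidden)).astype(np.float32)
outlier_channels = rng.choice(hidden, size=8, replace=False)
acts[:, outlier_channels] *= 50.0  # synthetic outlier channels (assumption)

for block in (4096, 2048, 1024, 512, 256, 128):
    print(block, group_quant_mse(acts, bits=4, block_size=block))
```

Coarse blocks force ordinary channels to share a scale with the outliers, so most of their values round to zero; finer blocks confine each outlier's large scale to the few blocks that contain it.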