Context Parallelism for Scalable Million-Token Inference

Amy Yang, Jingyi Yang, Aya Ibrahim, Xinfeng Xie, Bangsheng Tang, Grigory Sizov, Jeremy Reizenstein, Jongsoo Park, Jianyu Huang

TL;DR

This paper tackles the latency and scalability bottlenecks of long-context LLM inference by introducing context parallelism (CP) and two lossless ring attention variants, pass-KV and pass-Q. CP distributes tokens across multiple ranks and uses load-balanced sharding and ring-based KV/Q communication to overlap computation with communication, achieving near-linear scaling on up to 128 GPUs. The authors demonstrate state-of-the-art performance on Llama3 405B, obtaining 1M-context prefill in 77s and 128K-context prefill in 3.8s across 16 nodes, with high model FLOPS utilization (MFU) and robust performance over both RDMA and TCP interconnects. They also provide adaptive heuristics that switch between pass-KV and pass-Q based on the KV cache hit rate, and discuss decode performance and future directions toward combining exact context with retrieval for ultra-long contexts.
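The intuition behind switching on the KV cache hit rate can be illustrated with a simplified, communication-volume-only sketch. Assumptions, not taken from the paper: pass-Q circulates the queries of the T new tokens, pass-KV circulates K and V for all T + P tokens (P cached), activations are bf16, and the cost of returning partial outputs in pass-Q and any compute/communication overlap are ignored. The names `AttnShape`, `pass_q_bytes`, `pass_kv_bytes`, and `choose_variant` are hypothetical. Setting the two volumes equal gives a crossover at a miss rate T/(T+P) = 2·n_kv_heads/n_heads, i.e. 12.5% new tokens for a model with 128 query heads and 8 KV heads (Llama3 405B's head counts from its public config).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttnShape:
    n_heads: int             # query heads
    n_kv_heads: int          # key/value heads (grouped-query attention)
    head_dim: int
    bytes_per_elem: int = 2  # assumes bf16 activations


def pass_q_bytes(shape: AttnShape, new_tokens: int) -> int:
    """Size of the Q tensor that circulates in pass-Q (only new tokens have queries)."""
    return new_tokens * shape.n_heads * shape.head_dim * shape.bytes_per_elem


def pass_kv_bytes(shape: AttnShape, new_tokens: int, cached_tokens: int) -> int:
    """Size of the K and V tensors that circulate in pass-KV (cached plus new tokens)."""
    total = new_tokens + cached_tokens
    return 2 * total * shape.n_kv_heads * shape.head_dim * shape.bytes_per_elem


def choose_variant(shape: AttnShape, new_tokens: int, cached_tokens: int) -> str:
    """Pick the ring variant that moves fewer bytes (communication volume only)."""
    if pass_q_bytes(shape, new_tokens) < pass_kv_bytes(shape, new_tokens, cached_tokens):
        return "pass-Q"
    return "pass-KV"


llama3_405b = AttnShape(n_heads=128, n_kv_heads=8, head_dim=128)
print(choose_variant(llama3_405b, new_tokens=128_000, cached_tokens=0))      # pass-KV
print(choose_variant(llama3_405b, new_tokens=1_000, cached_tokens=127_000))  # pass-Q
```

With grouped-query attention a query tensor is much larger per token than K and V combined, so full prefill favors pass-KV, while a high cache hit rate (few new tokens against a long persistent KV cache) favors pass-Q.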

Abstract

We present context parallelism for long-context large language model inference, which achieves near-linear scaling of long-context prefill latency with up to 128 H100 GPUs across 16 nodes. In particular, our method achieves 1M-context prefill with the Llama3 405B model in 77s (93% parallelization efficiency, 63% FLOPS utilization) and 128K-context prefill in 3.8s. We develop two lossless, exact ring attention variants, pass-KV and pass-Q, which cover a wide range of use cases (full prefill, persistent-KV prefill, and decode) with state-of-the-art performance. Benchmarks on H100 GPU hosts interconnected with RDMA and with TCP show similar scalability for long-context prefill, demonstrating that our method scales well in common commercial data centers with medium-to-low inter-host bandwidth.
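Ring attention can be lossless and exact because softmax attention over a long KV sequence can be computed block by block and the partial results combined exactly using each block's log-sum-exp. Below is a minimal single-process NumPy sketch of the pass-KV idea, assuming contiguous (not load-balanced) shards and causal masking by global token position: each simulated rank keeps its Q shard while the KV shards rotate around the ring, and partial outputs are merged with log-sum-exp rescaling. All function names are hypothetical. A real implementation runs one rank per GPU and overlaps each KV send/receive with the attention compute of the current step, which this sketch does not model. Pass-Q computes the same partial blocks but circulates Q instead, so the merge math is the same.

```python
import numpy as np


def block_attn(q, k, v, q_pos, k_pos):
    """Causal attention of one Q block against one KV block.

    Returns the block-normalized output and the per-row log-sum-exp (lse),
    with out=0 and lse=-inf for rows whose keys are all masked.
    """
    s = q @ k.T / np.sqrt(q.shape[-1])
    s = np.where(k_pos[None, :] > q_pos[:, None], -np.inf, s)  # mask future keys
    m = s.max(axis=1, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)  # avoid inf - inf on fully masked rows
    p = np.exp(s - m)
    denom = p.sum(axis=1, keepdims=True)
    has_keys = denom > 0
    safe = np.where(has_keys, denom, 1.0)
    out = np.where(has_keys, (p @ v) / safe, 0.0)
    lse = np.where(has_keys, m + np.log(safe), -np.inf)
    return out, lse


def merge(out_a, lse_a, out_b, lse_b):
    """Exactly combine two partial softmax results using their log-sum-exps."""
    lse = np.logaddexp(lse_a, lse_b)
    ref = np.where(np.isfinite(lse), lse, 0.0)
    return np.exp(lse_a - ref) * out_a + np.exp(lse_b - ref) * out_b, lse


def ring_pass_kv(q_shards, k_shards, v_shards, pos_shards):
    """Simulate ring pass-KV: Q stays on its rank, KV shards rotate around the ring."""
    n = len(q_shards)
    outs = [np.zeros_like(q) for q in q_shards]
    lses = [np.full((q.shape[0], 1), -np.inf) for q in q_shards]
    held = list(range(n))  # held[r] = index of the KV shard currently on rank r
    for _ in range(n):
        for r in range(n):
            j = held[r]
            out_blk, lse_blk = block_attn(q_shards[r], k_shards[j], v_shards[j],
                                          pos_shards[r], pos_shards[j])
            outs[r], lses[r] = merge(outs[r], lses[r], out_blk, lse_blk)
        held = [held[(r - 1) % n] for r in range(n)]  # rank r receives from rank r-1
    return outs


def full_causal_attn(q, k, v):
    """Reference: ordinary causal softmax attention over the whole sequence."""
    t = q.shape[0]
    s = q @ k.T / np.sqrt(q.shape[-1])
    s = np.where(np.triu(np.ones((t, t), dtype=bool), k=1), -np.inf, s)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


rng = np.random.default_rng(0)
T, D, N = 16, 8, 4  # tokens, head dim, simulated CP ranks
q, k, v = (rng.standard_normal((T, D)) for _ in range(3))
pos = np.arange(T)
outs = ring_pass_kv(np.split(q, N), np.split(k, N), np.split(v, N), np.split(pos, N))
assert np.allclose(np.concatenate(outs), full_causal_attn(q, k, v))
```

The final assertion checks that the ring result matches ordinary causal attention up to floating-point rounding, which is the sense in which the ring variants are lossless.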

Paper Structure

This paper contains 29 sections, 19 equations, 10 figures, 9 tables, and 5 algorithms.

Figures (10)

  • Figure 1: Load-balanced CP sharding with fused inputs in full prefill with 2 CP ranks (CP2). We have 2 input sequences: $S1$, $S2$. Each is partitioned evenly into 4 chunks: $Q_i$ / $K_i$, where $i=1, 2, 3, 4$.
  • Figure 2: Load-balanced CP sharding with fused inputs in partial prefill with 2 CP ranks (CP2). We have 2 input sequences: $S1$, $S2$. Load-balanced sharding is applied to the new-token $Q_i$ dimension (4 chunks), regardless of how the cached-token dimension $K_i$ is partitioned in partial prefill.
  • Figure 3: Ring Pass-KV Attention with 4 CP ranks (CP4).
  • Figure 4: Ring Pass-Q Attention with 4 CP ranks (CP4).
  • Figure 5: Context parallelism across nodes and tensor parallelism within nodes, with 2 CP ranks (CP2).
  • ...and 5 more figures
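Figures 1 and 2 refer to load-balanced sharding. Under causal masking a later token attends to more keys than an earlier one, so giving each rank one contiguous slice would overload the last rank. The sketch below assumes the common scheme consistent with Figure 1's split of each sequence into 4 chunks for CP2: each sequence is cut into 2×CP equal chunks and rank r takes chunk r plus its mirror chunk 2×CP−1−r (in Figure 1's 1-based naming this would give rank 0 chunks 1 and 4, and rank 1 chunks 2 and 3). Fused inputs are handled by sharding each sequence independently; per Figure 2, in partial prefill the same split would apply to new-token positions only. All function names are hypothetical.

```python
def load_balanced_shards(seq_len: int, cp_size: int) -> list[list[int]]:
    """Cut one sequence into 2*cp_size equal chunks; rank r gets chunks r and 2*cp_size-1-r."""
    n_chunks = 2 * cp_size
    if seq_len % n_chunks:
        raise ValueError("seq_len must be divisible by 2 * cp_size in this sketch")
    c = seq_len // n_chunks
    chunks = [list(range(i * c, (i + 1) * c)) for i in range(n_chunks)]
    return [chunks[r] + chunks[n_chunks - 1 - r] for r in range(cp_size)]


def shard_fused(seq_lens: list[int], cp_size: int) -> list[list[tuple[int, list[int]]]]:
    """Shard each fused sequence independently; returns per-rank (seq_id, positions)."""
    per_rank: list[list[tuple[int, list[int]]]] = [[] for _ in range(cp_size)]
    for seq_id, seq_len in enumerate(seq_lens):
        for rank, positions in enumerate(load_balanced_shards(seq_len, cp_size)):
            per_rank[rank].append((seq_id, positions))
    return per_rank


def causal_work(positions: list[int]) -> int:
    """A query at position p attends to p + 1 keys (itself and all earlier tokens)."""
    return sum(p + 1 for p in positions)


shards = load_balanced_shards(8, cp_size=2)
print(shards)                            # [[0, 1, 6, 7], [2, 3, 4, 5]]
print([causal_work(s) for s in shards])  # [18, 18]
print([causal_work(list(range(0, 4))), causal_work(list(range(4, 8)))])  # contiguous: [10, 26]
```

Pairing an early chunk with its mirror late chunk equalizes the causal attention work per rank (18 vs 18 query-key pairs here), whereas a contiguous split gives the second rank more than twice the work of the first.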