
KV-Runahead: Scalable Causal LLM Inference by Parallel Key-Value Cache Generation

Minsik Cho, Mohammad Rastegari, Devang Naik

TL;DR

This work tackles the long TTFT bottleneck in prompt-phase causal LLM inference by proposing KV-Runahead, a parallelization scheme that dual-purposes the KV-cache to populate the required keys and values across multiple processes. It combines uneven, context-level load-balancing with asynchronous, neighbor-to-neighbor KV-cache transfers to minimize $QK^T$ computation and data movement while avoiding global synchronization. The authors provide a theoretical bound on TTFT across processes and use a hierarchical grid search to build a partitioning lookup table for TTFT optimization, reporting speedups of over $1.4\times$ on Llama 7B and $1.6\times$ on Falcon 7B under various network conditions. The approach is easy to integrate with existing KV-cache-enabled LLMs and is robust to bandwidth fluctuations, enabling scalable, low-latency prompting for long-context and retrieval-augmented tasks.
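The partitioning lookup table can be illustrated with a toy sketch. Several assumptions are not from the paper: the cost of a partition is approximated by the largest per-process $QK^T$ dot-product count, communication time is ignored, and an exhaustive search stands in for the authors' hierarchical grid search over measured TTFT. Process $p_i$ owning $c_i$ tokens is assumed to compute $c_i \times (c_0 + \dots + c_i)$ dot-products, because it attends to its own chunk and to the KV-cache of all earlier tokens. All function names are hypothetical.

```python
from itertools import combinations


def kv_runahead_dot_products(partition):
    """QK^T dot-products per process when p_i owns partition[i] tokens and
    attends to its own chunk plus the KV-cache of every earlier token."""
    costs, prefix = [], 0
    for chunk in partition:
        prefix += chunk
        costs.append(chunk * prefix)  # chunk queries x (cached + local) keys
    return costs


def compositions(context, procs):
    """Every split of `context` tokens into `procs` non-empty contiguous chunks."""
    for cuts in combinations(range(1, context), procs - 1):
        bounds = (0, *cuts, context)
        yield tuple(b - a for a, b in zip(bounds, bounds[1:]))


def best_partition(context, procs):
    # Proxy objective (assumption): minimise the busiest process's dot-products.
    return min(compositions(context, procs),
               key=lambda p: max(kv_runahead_dot_products(p)))


def build_lookup_table(contexts, procs):
    """Hypothetical offline step: precompute one partition per context length."""
    return {(c, procs): best_partition(c, procs) for c in contexts}


table = build_lookup_table([9, 16, 32], procs=3)
print(table[(9, 3)])  # (4, 3, 2)
```

For 9 tokens over 3 processes this proxy picks the uneven split (4, 3, 2): earlier processes, which have a shorter prefix to attend to, receive more tokens. A real table would be keyed on measured TTFT, which also reflects communication and hardware effects.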

Abstract

Large language model (LLM) inference has two phases: the prompt (or prefill) phase, which outputs the first token, and the extension (or decoding) phase, which generates the subsequent tokens. In this work, we propose an efficient parallelization scheme, KV-Runahead, to accelerate the prompt phase. The key observation is that the extension phase generates tokens faster than the prompt phase because of the key-value cache (KV-cache). Hence, KV-Runahead parallelizes the prompt phase by orchestrating multiple processes to populate the KV-cache, minimizing the time-to-first-token (TTFT). Dual-purposing the KV-cache scheme has two main benefits. First, since the KV-cache is designed to leverage the causal attention map, we minimize computation and communication automatically. Second, since it already exists for the extension phase, KV-Runahead is easy to implement. We further propose context-level load-balancing to handle uneven KV-cache generation (due to the causal attention) and to optimize TTFT. Compared with existing parallelization schemes such as tensor or sequence parallelization, where keys and values are generated locally and exchanged via all-gather collectives, our experimental results demonstrate that KV-Runahead can offer speedups of over 1.4x and 1.6x for Llama 7B and Falcon 7B, respectively.
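Dual-purposing the KV-cache is exact rather than approximate. The sketch below emulates the processes one after another in pure Python and checks that chaining the KV-cache reproduces single-process causal attention. Assumptions: $Q$, $K$, $V$ are given directly for a single head (no projections, no batching), the parallel execution and point-to-point sends are emulated by a sequential loop, and the partition (4, 3, 2) is only an example. All names are hypothetical.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def causal_attention(q_rows, k_rows, v_rows, offset):
    """Softmax attention for queries at absolute positions offset, offset+1, ...

    The query at position t only sees keys/values 0..t (causal mask), so
    k_rows/v_rows must hold at least every row up to the last query position.
    """
    scale = 1.0 / math.sqrt(len(q_rows[0]))
    out = []
    for i, q in enumerate(q_rows):
        t = offset + i
        scores = [dot(q, k) * scale for k in k_rows[: t + 1]]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        # zip stops after t + 1 rows, matching the causal window
        out.append([sum(w * v[d] for w, v in zip(weights, v_rows)) / total
                    for d in range(len(v_rows[0]))])
    return out


def kv_runahead(q, k, v, partition):
    """Emulate processes p_0..p_{n-1} in order: p_i appends its local (K, V)
    to the KV-cache received from p_{i-1}, attends with its own queries, and
    would then forward the grown cache to p_{i+1}."""
    cache_k, cache_v, outputs, start = [], [], [], 0
    for chunk in partition:
        end = start + chunk
        cache_k = cache_k + k[start:end]
        cache_v = cache_v + v[start:end]
        outputs += causal_attention(q[start:end], cache_k, cache_v, offset=start)
        start = end
    return outputs


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
CONTEXT, HEAD_DIM = 9, 4
Q, K, V = (random_matrix(CONTEXT, HEAD_DIM, rng) for _ in range(3))
reference = causal_attention(Q, K, V, offset=0)
chunked = kv_runahead(Q, K, V, partition=(4, 3, 2))
print(all(math.isclose(a, b) for r, c in zip(reference, chunked) for a, b in zip(r, c)))  # True
```

Each emulated process performs exactly the same floating-point operations as the single-process reference, so the outputs match; only where the work happens, and which $(K, V)$ rows must be moved, changes.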

Paper Structure

This paper contains 14 sections, 3 equations, 11 figures, and 5 tables.

Figures (11)

  • Figure 1: LLM inference begins with the prompt phase, which generates the KV-cache and the first token that drive the extension phase, as in (a). Inside each layer of the LLM, as in (b), a causal attention map ($QK^T$) is built to compute the attention $A$ from queries, keys, and values $(Q,K,V)$. Computing attention thus has $O(C^2)$ complexity, where $C$ is the user context length.
  • Figure 2: $QK^T$ computation coverage using BLAS matrix-matrix multiplication: by linking each context partition to the KV-cache, we can closely approximate the lower triangular part and minimize unnecessary dot products. Note that the upper triangular part of the attention will be masked out to enforce causality.
  • Figure 3: Comparing the existing tensor+sequence parallel scheme with the proposed KV-Runahead for parallel LLM inference.
  • Figure 4: Tensor/sequence parallel inference over 3 processes $p_{\{0,1,2\}}$ within a layer to compute the attention map ($QK^T$) and the final attention $A$: each process computes an equal amount of $(Q,K,V)$ in (a), and then globally shares $(K,V)$ using all-gather collectives to compute an equally sized partial $QK^T$ (i.e., 27 dot-products on each) and partial $A$. Such all-gather operations require global synchronization and incur traffic of 36 $(K,V)$ entries (i.e., the number of blue rows in $(K,V)$).
  • Figure 5: KV-Runahead execution over 3 processes $p_{\{0,1,2\}}$ within a layer to compute the attention map ($QK^T$) and the final attention $A$: each process computes a different amount of $(Q,K,V)$ in (a), and the maximum amount of $QK^T$ work is 21 dot-products on $p_1$ (in contrast to 27 in Fig. 4(b)). The locally computed $(K, V)$ are passed down to the following processes as KV-cache using a one-way point-to-point send (i.e., $p_0 \rightarrow p_1 \rightarrow p_2$). This communication is much cheaper than the global all-gather in Fig. 4(b): the traffic incurred in KV-Runahead is 22 entries (i.e., the number of blue rows in $(K,V)$), compared with 36 in Fig. 4.
  • ...and 6 more figures
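The dot-product and traffic counts quoted in the Figure 4 and Figure 5 captions can be reproduced with simple arithmetic. This is a minimal sketch. The captions do not state the context size or the split, so the following are inferred from the counts and should be treated as assumptions: a 9-token context split 3/3/3 for tensor/sequence parallelism and 4/3/2 for KV-Runahead, with $K$ rows and $V$ rows counted separately as traffic. Function names are hypothetical.

```python
def tensor_sequence_parallel(context, procs):
    """Equal split: each process all-gathers the other processes' K and V rows,
    then computes (context / procs) x context dot-products."""
    chunk = context // procs
    dot_products = [chunk * context] * procs
    traffic = procs * 2 * (context - chunk)  # received K rows + V rows, all processes
    return dot_products, traffic


def kv_runahead_cost(partition):
    """Uneven split: p_i attends to its chunk plus the cached prefix, then sends
    the cumulative K and V rows one way to p_{i+1} (the last process sends nothing)."""
    dot_products, traffic, prefix = [], 0, 0
    for i, chunk in enumerate(partition):
        prefix += chunk
        dot_products.append(chunk * prefix)
        if i < len(partition) - 1:
            traffic += 2 * prefix
    return dot_products, traffic


print(tensor_sequence_parallel(9, 3))  # ([27, 27, 27], 36)
print(kv_runahead_cost((4, 3, 2)))     # ([16, 21, 18], 22)
print(kv_runahead_cost((3, 3, 3)))     # ([9, 18, 27], 18)
```

Under these assumptions, an even 3/3/3 split in KV-Runahead would leave $p_2$ with 27 dot-products, no better than the all-gather baseline. This is why the uneven, context-level partition matters.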