Bifurcated Attention: Accelerating Massively Parallel Decoding with Shared Prefixes in LLMs

Ben Athiwaratkun, Sujan Kumar Gonugondla, Sanjay Krishna Gouda, Haifeng Qian, Hantian Ding, Qing Sun, Jun Wang, Jiacheng Guo, Liangfu Chen, Parminder Bhatia, Ramesh Nallapati, Sudipta Sengupta, Bing Xiang

TL;DR

This work tackles the latency bottleneck in massively parallel autoregressive decoding that is caused by memory IO on the KV cache during incremental decoding. It introduces context-aware bifurcated attention, which splits attention over the KV cache into a shared-context part and a decoding part. This keeps FLOPs unchanged while substantially reducing memory IO, enabling much larger batch sizes and longer contexts with minimal latency impact. The study also analyzes generalized multi-group attention across the spectrum from multi-head to multi-query, showing how the number of attention groups $g$ trades expressiveness for efficiency and how a larger model size can compensate for reduced expressiveness. Empirical results on 7B-scale models show substantial speedups (from over 2× to over 6×, depending on the configuration), compatibility with torch.compile, and practical benefits for applications that require multiple parallel samples or re-ranking, such as code generation and reasoning tasks. Overall, bifurcated attention offers a practical, IO-aware path to real-time, high-throughput LLM decoding without sacrificing accuracy, since it computes exactly the same attention output, which makes it valuable for latency-sensitive AI services.
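To make the role of $g$ concrete, the following sketch counts the key and value elements cached per layer for multi-head ($g = h$), multi-group, and multi-query ($g = 1$) attention, and shows how query heads map onto KV groups. The head count, head size, batch size, and context length are hypothetical values chosen for illustration, and the helper names are not from the paper; the only point is that the KV cache, and hence the KV memory IO per decoding step, scales linearly with $g$.

```python
def kv_cache_elements(batch: int, context_len: int, num_groups: int, head_dim: int) -> int:
    """Number of cached K and V elements for ONE layer.

    Assumed layout (illustrative): each of the `num_groups` KV heads stores one
    key and one value vector of size `head_dim` per token per sequence.
    """
    return 2 * batch * context_len * num_groups * head_dim


def kv_group_of_head(query_head: int, num_heads: int, num_groups: int) -> int:
    """Index of the KV group a query head reads from (heads split evenly)."""
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    heads_per_group = num_heads // num_groups
    return query_head // heads_per_group


NUM_HEADS, HEAD_DIM = 32, 128  # hypothetical 7B-like shape, not taken from the paper
for g in (NUM_HEADS, 8, 1):  # multi-head, multi-group, multi-query
    n = kv_cache_elements(batch=16, context_len=8192, num_groups=g, head_dim=HEAD_DIM)
    print(f"g={g:2d}: {n:,} cached K/V elements per layer")
```

With these inputs the multi-query cache ($g = 1$) is 32 times smaller than the multi-head cache ($g = h = 32$), which is the efficiency side of the trade-off; the expressiveness side is what the scaling-law comparison in the paper measures.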

Abstract

This study introduces bifurcated attention, a method designed to improve language model inference in shared-context batch decoding scenarios. Our approach addresses redundant memory IO, a critical contributor to latency at high batch sizes and long context lengths. Bifurcated attention achieves this by dividing the attention computation during incremental decoding into two separate GEMM operations: one over the KV cache from prefill (the shared context), and another over the KV cache produced during decoding. While keeping the computational load (FLOPs) of standard attention, bifurcated attention computes exactly the same result with significantly reduced memory IO. Our empirical results show over 2.1$\times$ speedup when sampling 16 output sequences and more than 6.2$\times$ speedup when sampling 32 sequences at context lengths exceeding 8k tokens on a 7B model that uses multi-head attention. The efficiency gains from bifurcated attention translate into lower latency, making it particularly suitable for real-time applications. For instance, it enables massively parallel answer generation without substantially increasing latency, which improves performance when combined with post-processing techniques such as re-ranking.
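The exactness of the split described above can be checked numerically with a small single-head sketch in plain Python. It assumes one attention head, scaled dot-product attention with a $1/\sqrt{d}$ factor, a context KV cache shared by all $b$ sampled sequences, and a per-sequence cache of decoded tokens; all function and variable names are illustrative. `standard_attention` attends each sequence over the concatenation of context and decoded keys, while `bifurcated_attention` computes the two parts separately, concatenates the logits before a single softmax, and sums the two weight-value products. In a real kernel each part would be one batched GEMM, and the context part would read $K_c$ and $V_c$ from memory once rather than once per sequence; the pure-Python loops here only demonstrate that the outputs agree.

```python
import math
import random

Vector = list[float]
Matrix = list[Vector]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(xs: Vector) -> Vector:
    mx = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_sum(weights: Vector, rows: Matrix, dim: int) -> Vector:
    out = [0.0] * dim
    for w, row in zip(weights, rows):
        for j, x in enumerate(row):
            out[j] += w * x
    return out


def standard_attention(q: Matrix, k_ctx: Matrix, v_ctx: Matrix,
                       k_dec: list[Matrix], v_dec: list[Matrix]) -> Matrix:
    """Each sequence i attends over its own [context ; decoded] keys and values."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    outputs = []
    for i, q_i in enumerate(q):
        keys = k_ctx + k_dec[i]  # context cache treated as a per-sequence copy
        values = v_ctx + v_dec[i]
        weights = softmax([scale * dot(q_i, k) for k in keys])
        outputs.append(weighted_sum(weights, values, d))
    return outputs


def bifurcated_attention(q: Matrix, k_ctx: Matrix, v_ctx: Matrix,
                         k_dec: list[Matrix], v_dec: list[Matrix]) -> Matrix:
    """Same output, with the context and decoding parts computed separately."""
    d = len(q[0])
    m = len(k_ctx)
    scale = 1.0 / math.sqrt(d)
    # Part 1: <q, K_c>, all b queries against the single shared context cache
    # (one (b x d) @ (d x m) GEMM in a real kernel).
    logits_ctx = [[scale * dot(q_i, k) for k in k_ctx] for q_i in q]
    # Part 2: <q_i, K_d[i]>, each query against its own decoded keys.
    logits_dec = [[scale * dot(q_i, k) for k in k_dec[i]] for i, q_i in enumerate(q)]
    outputs = []
    for i in range(len(q)):
        # Concatenate the logits so the softmax normalizes over the full length.
        w = softmax(logits_ctx[i] + logits_dec[i])
        ctx_part = weighted_sum(w[:m], v_ctx, d)     # weights x V_c
        dec_part = weighted_sum(w[m:], v_dec[i], d)  # weights x V_d[i]
        outputs.append([x + y for x, y in zip(ctx_part, dec_part)])
    return outputs


def random_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
b, m, n, d = 4, 6, 3, 8  # batch size, context length, decoded steps, head dim
q = random_matrix(b, d)
k_ctx, v_ctx = random_matrix(m, d), random_matrix(m, d)
k_dec = [random_matrix(n, d) for _ in range(b)]
v_dec = [random_matrix(n, d) for _ in range(b)]

ref = standard_attention(q, k_ctx, v_ctx, k_dec, v_dec)
out = bifurcated_attention(q, k_ctx, v_ctx, k_dec, v_dec)
max_diff = max(abs(x - y) for r, o in zip(ref, out) for x, y in zip(r, o))
print(f"max |standard - bifurcated| = {max_diff:.2e}")
```

The two results differ only by floating-point summation order, because the logits and softmax weights are identical and only the weight-value sum is regrouped into a context term and a decoding term.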

Paper Structure (50 sections, 5 equations, 10 figures, 8 tables)

Figures (10)

  • Figure 1: Illustration of the two phases of language model inference, context encoding and incremental decoding, as well as different inference scenarios. In the batch inference scenario, we group multiple inputs into a batch and perform both context encoding and the subsequent incremental decoding steps on the whole batch. In the single-context batch sampling scenario, we perform context encoding on a single input to obtain the context KV cache, then perform incremental decoding (with temperature sampling) to obtain potentially different generations.
  • Figure 2: Context-aware bifurcated attention for single-context batch sampling. The figure depicts the incremental decoding step in which the batched query $q$ attends to the cached key tensor $K$; different colors in the $q$ tensor correspond to different batch indices. The key tensor consists of two parts: the key cache for the single context, $K_c$ (computed during context encoding, as in Figure 1), and the key cache for previous incremental decoding steps, $K_d$. The query-key attention is bifurcated into two parts, $\langle q, K_c \rangle$ and $\langle q, K_d \rangle$, which are then joined back via concatenation, yielding identical results with the same FLOPs but lower memory IO (see the bifurcated key equation, eq:bifurcated_k). The weight-value attention is bifurcated similarly (see the bifurcated value equation, eq:bifurcated_v).
  • Figure 3: (Left) Plots of validation loss versus model size show that the scaling-law curves of different attention mechanisms differ in expressiveness, i.e., in performance at a given model size. That is, at a fixed model size, capability depends on $g$, with higher $g$ yielding the best capabilities. (Right) A similar trend holds when code generation ability is used as a proxy for general capability. Here, we average the execution pass rates on the Multi-lingual HumanEval and MBXP benchmarks across 13 programming languages.
  • Figure 4: High-level latency comparison between an MH model and a larger MQ model with comparable capabilities. Overall, the larger MQ model incurs an overhead in initial context encoding latency due to its additional compute. At low context length and batch size, the per-step latency of MQ is also slightly higher at first, due to the memory IO required by the larger model size, but it changes little as context length $m$ or batch size $b$ grows, as opposed to the multi-head case, where per-step latency can grow more rapidly with $m$ and $b$.
  • Figure 5: Incremental decoding (per-step) latency and context encoding latency as a function of input context length. In this plot, we compare a multi-head (MH) model and a multi-query (MQ) model of comparable capabilities, where the MQ model is slightly larger. (Leftmost: per-step incremental decoding latency) For low context lengths such as $m < 2500$, the MQ model has higher inference latency because of its larger size. However, its latency grows much more slowly with context length (almost flat), resulting in lower per-step latency when the context length is high. (Second: context encoding latency) Context encoding latency depends on FLOPs, where MH and MQ are quite similar; the MQ model is slightly larger and therefore has a steeper curve. (Third and fourth: total latency for 15 or 256 generated steps) These two plots illustrate the total latency, which is the context encoding latency plus the number of steps times the per-step incremental decoding latency. The benefit of the MQ model becomes clear with many decoding steps (256), whereas with 15 generated tokens the total latency of MQ can still be slightly higher than that of MH.
  • ...and 5 more figures
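The Figure 2 caption states that bifurcation keeps FLOPs unchanged but lowers memory IO. A rough accounting of that claim is sketched below under simplifying assumptions that are not the paper's exact cost model: without bifurcation, each of the $b$ sampled sequences reads its own copy of the $m$-token context cache; with bifurcation, the shared context cache is read once per step; and only K/V elements read for one layer in one decoding step are counted. The function name and parameters are hypothetical. The printed ratio is a ratio of KV reads alone (model-weight reads and other work are ignored), so it should not be read as a prediction of end-to-end speedup.

```python
def kv_elements_read_per_step(
    batch: int,
    context_len: int,
    decoded_len: int,
    num_groups: int,
    head_dim: int,
    bifurcated: bool,
) -> int:
    """K and V elements loaded from memory for one layer in one decoding step.

    Simplified, illustrative model:
    - without bifurcation, every sequence reads its own copy of the context
      cache plus its own decoded tokens;
    - with bifurcation, the shared context cache is read once for the whole
      batch, and each sequence still reads its own decoded tokens.
    """
    elements_per_token = 2 * num_groups * head_dim  # one K and one V vector per KV group
    if bifurcated:
        tokens_read = context_len + batch * decoded_len
    else:
        tokens_read = batch * (context_len + decoded_len)
    return tokens_read * elements_per_token


# Hypothetical setting: 32 samples, 8k-token prompt, 256 tokens decoded so far,
# multi-head attention with 32 KV heads of size 128 (g = h = 32).
setting = dict(batch=32, context_len=8192, decoded_len=256, num_groups=32, head_dim=128)
io_standard = kv_elements_read_per_step(**setting, bifurcated=False)
io_bifurcated = kv_elements_read_per_step(**setting, bifurcated=True)
print(f"standard:   {io_standard:,} K/V elements")
print(f"bifurcated: {io_bifurcated:,} K/V elements")
print(f"KV-read ratio: {io_standard / io_bifurcated:.1f}x")  # 16.5x for these inputs
```

Because `num_groups` and `head_dim` multiply both counts, the ratio in this simplified model depends only on $b$, $m$, and the number of decoded tokens: it grows with batch size and context length, and shrinks as decoding proceeds and the per-sequence part of the cache becomes larger.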