APB: Accelerating Distributed Long-Context Inference by Passing Compressed Context Blocks across GPUs

Yuxiang Huang, Mingye Li, Xu Han, Chaojun Xiao, Weilin Zhao, Sun Ao, Hao Zhou, Jie Zhou, Zhiyuan Liu, Maosong Sun

TL;DR

Long-context inference in Transformer-based LLMs is bottlenecked by $O(n^2)$ attention during the prefill stage, which hinders scaling to ultra-long inputs. APB is a distributed framework that passes compressed KV context blocks across $H$ GPUs. It combines anchor blocks, passing blocks, localized KV compression via retaining heads, and AllGather-based communication, and is built on a FlashAttn kernel. On ∞Bench and RULER, APB achieves speedups of up to 9.2x over FlashAttn, 4.2x over RingAttn, and 1.6x over StarAttn while maintaining or improving task performance across models and input lengths. By reducing attention compute and sharing only essential context across hosts, APB makes ultra-long-context inference more practical to deploy; it also remains compatible with KV quantization.

Abstract

While long-context inference is crucial for advancing large language model (LLM) applications, its prefill speed remains a significant bottleneck. Current approaches, including sequence parallelism strategies and compute reduction through approximate attention mechanisms, still fall short of delivering optimal inference efficiency. This hinders scaling the inputs to longer sequences and processing long-context queries in a timely manner. To address this, we introduce APB, an efficient long-context inference framework that leverages multi-host approximate attention to enhance prefill speed by reducing compute and enhancing parallelism simultaneously. APB introduces a communication mechanism for essential key-value pairs within a sequence parallelism framework, enabling faster inference while maintaining task performance. We implement APB by incorporating a tailored FlashAttn kernel alongside optimized distribution strategies, supporting diverse models and parallelism configurations. APB achieves speedups of up to 9.2x, 4.2x, and 1.6x compared with FlashAttn, RingAttn, and StarAttn, respectively, without any observable task performance degradation. We provide the implementation and experiment code of APB at https://github.com/thunlp/APB.

Paper Structure

This paper contains 31 sections, 3 equations, 7 figures, 20 tables, and 3 algorithms.

Figures (7)

  • Figure 1: The prefill speed of methods with and without sequence parallelism when processing different input lengths. "SP" indicates sequence parallelism. An "x" marks a setting that triggers an out-of-memory error.
  • Figure 2: The framework of APB. The input document $d$ is split into blocks $\mathbf{B}_1, \mathbf{B}_2, \mathbf{B}_3$ and distributed across 3 hosts. The anchor block is denoted as "$\mathbf{A}$", the passing block as "$\mathbf{P}$", and the compressor as "$\mathcal{C}$". Each block is first prepended with an anchor block. When calculating attention, the context block $\mathbf{B}$ is compressed into $\mathbf{B}^C$ using the compressor $\mathcal{C}$. Subsequently, the passing block is constructed after an AllGather communication. Finally, attention is performed using a modified attention mask. Passing blocks are discarded after the attention computation.
  • Figure 3: The inference speed and model performance of APB and all the baselines. The top-right direction represents the optimal tradeoff between speed and performance. APB achieves the best tradeoff of the two metrics. FlashAttn, RingAttn, and Ulysses share the same performance as they are all FullAttn methods.
  • Figure 4: The performance, speed, and the amount of compute of various methods under different input lengths. APB consistently outperforms other methods with better performance, faster speed, and lower compute.
  • Figure 5: The wall-time breakdown of prefill for various methods on 128K context.
  • ...and 2 more figures
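
The block layout described in the Figure 2 caption can be sketched as a single-process simulation. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names (`split_blocks`, `compress`, `apb_layout`, `keys_per_host`) are hypothetical, random scores stand in for the retaining-head compressor $\mathcal{C}$, the anchor block is assumed to be the first few tokens of the document, host 0 is assumed to need no anchor because those tokens already open its own block, and each host's passing block is assumed to hold only the compressed blocks of preceding hosts. The AllGather step is simulated by a plain Python list that every host can read.

```python
import numpy as np


def split_blocks(n_tokens, n_hosts):
    """Split token positions 0..n_tokens-1 into contiguous per-host blocks."""
    return np.array_split(np.arange(n_tokens), n_hosts)


def compress(block, scores, keep):
    """Stand-in compressor: keep the `keep` highest-scoring positions.

    Assumption: the real compressor scores KV pairs with retaining heads;
    here `scores` is just an array of the same length as `block`.
    """
    if keep <= 0:
        return block[:0]
    if keep >= len(block):
        return block
    top = np.argsort(scores)[-keep:]
    return block[np.sort(top)]  # preserve original token order


def apb_layout(n_tokens, n_hosts, anchor_len, keep, seed=0):
    """Return, per host, the token positions of its anchor, passing, and local block."""
    rng = np.random.default_rng(seed)
    blocks = split_blocks(n_tokens, n_hosts)
    anchor = np.arange(anchor_len)  # assumption: anchor = start of the document

    # Each host compresses its own block locally (random scores as a placeholder).
    compressed = [compress(b, rng.random(len(b)), keep) for b in blocks]

    # Simulated AllGather: after it, every host holds all compressed blocks.
    gathered = compressed

    empty = np.array([], dtype=blocks[0].dtype)
    layouts = []
    for h, block in enumerate(blocks):
        # Assumption: host h only uses compressed blocks from hosts before it.
        passing = np.concatenate(gathered[:h]) if h > 0 else empty
        layouts.append({
            "anchor": anchor if h > 0 else empty,
            "passing": passing,
            "block": block,
        })
    return layouts


def keys_per_host(layouts):
    """Number of key positions each host attends over (anchor + passing + block)."""
    return [
        int(l["anchor"].size + l["passing"].size + l["block"].size)
        for l in layouts
    ]


if __name__ == "__main__":
    layouts = apb_layout(n_tokens=12, n_hosts=3, anchor_len=2, keep=2)
    print(keys_per_host(layouts))
```

With 12 tokens, 3 hosts, a 2-token anchor, and 2 retained positions per block, the hosts attend over 4, 8, and 10 key positions respectively, whereas the last host would attend over all 12 under full attention. In this sketch the passing block is never stored with the local block, mirroring the caption's note that passing blocks are discarded after the attention computation.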