FlashForge: Ultra-Efficient Prefix-Aware Attention for LLM Decoding
Zhibin Wang, Rui Ning, Chao Fang, Zhonghui Zhang, Xi Lin, Shaobo Ma, Mo Zhou, Xue Li, Zhongfeng Wang, Chengying Huan, Rong Gu, Kun Yang, Guihai Chen, Sheng Zhong, Chen Tian
TL;DR
This work targets the decode-stage bottleneck of prefix sharing in large language models by introducing FlashForge, a dedicated operator for prefix-shared decoding. It combines a novel shared-prefix attention kernel with tree-aware KV-cache management and a workload-balancing framework (cost estimator, task division, and scheduler) to handle irregular workloads and memory-bound access patterns. The approach yields substantial gains: an average $1.9\times$ speedup and a $120.9\times$ reduction in memory access over the previous state of the art (FlashDecoding), and up to $3.8\times$ lower end-to-end latency (time per output token) than vLLM, demonstrating strong practical impact for long-context inference. These results underscore the importance of co-designing data structures (KV-cache trees) with memory-hierarchy-aware kernels and adaptive scheduling to unlock efficient prefix sharing at scale.
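The following minimal sketch makes the memory-access argument concrete. It counts the KV-cache tokens read in one decode step for a prefix tree, first without sharing and then with sharing. It rests on several assumptions. Each tree node holds a contiguous run of cached tokens. Requests are the leaves. "Sharing" means each node's KV entries are read once per step. It ignores bytes per token, head count, block granularity, and caching effects. `PrefixNode`, `tokens_read_unshared`, and `tokens_read_shared` are hypothetical names for illustration, not part of FlashForge.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class PrefixNode:
    """Hypothetical KV-cache tree node: a run of cached tokens shared by all requests below it."""
    num_tokens: int
    children: list[PrefixNode] = field(default_factory=list)


def context_lengths(node: PrefixNode, parent_len: int = 0):
    """Yield the full context length (root-to-leaf token count) of every request (leaf)."""
    total = parent_len + node.num_tokens
    if not node.children:
        yield total
    for child in node.children:
        yield from context_lengths(child, total)


def tokens_read_unshared(root: PrefixNode) -> int:
    # Without prefix sharing, every request re-reads its whole context,
    # so a shared node is read once per request beneath it.
    return sum(context_lengths(root))


def tokens_read_shared(root: PrefixNode) -> int:
    # With prefix sharing (assumed here), each node's KV entries are read
    # exactly once per decode step, however many requests share that node.
    return root.num_tokens + sum(tokens_read_shared(c) for c in root.children)
```

Consider a 1,000-token shared prompt with three requests whose private suffixes are 10, 20, and 30 tokens. The unshared count is 3,060 KV tokens per step, against 1,060 with sharing, about 2.9x fewer. The ratio grows with the number of sharing requests and the prefix length. This toy count is not meant to reproduce the $120.9\times$ figure reported above.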
Abstract
Prefix sharing among multiple prompts creates an opportunity to combine the operations on the shared prefix. Meanwhile, attention computation in the decode stage, which becomes a critical bottleneck as context lengths grow, is memory-intensive: it requires heavy access to the key-value (KV) cache of the prefixes. In this paper, we therefore explore the potential of prefix sharing in decode-stage attention. However, the tree structure of prefix sharing makes it challenging to process shared KV-cache access patterns efficiently while managing complex dependencies and balancing irregular workloads. To address these challenges, we propose FlashForge, a dedicated attention kernel that combines the memory accesses of shared prefixes in the decode stage. FlashForge delivers two key innovations: a novel shared-prefix attention kernel that optimizes use of the memory hierarchy and exploits both intra-block and inter-block parallelism, and a comprehensive workload-balancing mechanism that estimates cost, divides tasks, and schedules execution. Experimental results show that, for decode-stage attention, FlashForge achieves an average 1.9x speedup and a 120.9x reduction in memory access compared to the state-of-the-art FlashDecoding kernel. It also achieves a 3.8x improvement in end-to-end time per output token compared to vLLM.
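The abstract does not spell out the kernel. What follows is therefore a minimal pure-Python sketch of the general decomposition that prefix-shared decoding builds on, not FlashForge's implementation. It assumes single-head attention with a $1/\sqrt{d}$ scale and a single shared prefix.

Attention over [shared prefix ++ private suffix] can be split into two partial results:

- one pass over the prefix KV, which streams each entry once and reuses it for every query that shares it;
- one per-request pass over that request's suffix.

The two partials are then combined with the log-sum-exp rescaling that FlashDecoding-style kernels use to merge split-KV results. The names `online_attention`, `merge`, and `shared_prefix_decode` are hypothetical.

```python
import math


def online_attention(queries, keys, values):
    """Single-head attention of several queries over one non-empty KV chunk.

    The outer loop streams each (key, value) pair once and reuses it for every
    query, keeping an online-softmax state per query: running max m, running
    sum z, and unnormalized accumulator acc. Returns one (output, m, z) per query.
    """
    scale = 1.0 / math.sqrt(len(queries[0]))
    dim = len(values[0])
    state = [(-math.inf, 0.0, [0.0] * dim) for _ in queries]
    for k, v in zip(keys, values):  # each KV entry is visited once...
        for i, q in enumerate(queries):  # ...and shared by all queries
            m, z, acc = state[i]
            s = scale * sum(a * b for a, b in zip(q, k))
            m_new = max(m, s)
            corr = math.exp(m - m_new)  # rescales the old state to the new max
            p = math.exp(s - m_new)
            state[i] = (m_new, z * corr + p, [a * corr + p * x for a, x in zip(acc, v)])
    return [([a / z for a in acc], m, z) for m, z, acc in state]


def merge(part_a, part_b):
    """Combine partial attention results over two disjoint KV chunks (log-sum-exp rescaling)."""
    out_a, m_a, z_a = part_a
    out_b, m_b, z_b = part_b
    m = max(m_a, m_b)
    w_a = z_a * math.exp(m_a - m)
    w_b = z_b * math.exp(m_b - m)
    out = [(w_a * x + w_b * y) / (w_a + w_b) for x, y in zip(out_a, out_b)]
    return out, m, w_a + w_b


def shared_prefix_decode(queries, prefix_k, prefix_v, suffixes):
    """One decode step for requests sharing a prefix; suffixes[i] = (keys_i, values_i)."""
    prefix_parts = online_attention(queries, prefix_k, prefix_v)  # one pass over shared KV
    outputs = []
    for q, prefix_part, (suffix_k, suffix_v) in zip(queries, prefix_parts, suffixes):
        suffix_part = online_attention([q], suffix_k, suffix_v)[0]  # private KV of this request
        outputs.append(merge(prefix_part, suffix_part)[0])
    return outputs
```

Each request's output matches ordinary attention over its concatenated prefix-plus-suffix context, up to floating-point rounding. The loop order in `online_attention` is the point of the sketch. In a GPU kernel, the analogous saving would come from loading each tile of prefix keys and values into on-chip memory once and reusing it for all sharing queries, rather than having every request re-read it from global memory.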
