DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference

Jinwei Yao, Kaiqi Chen, Kexun Zhang, Jiaxuan You, Binhang Yuan, Zeke Wang, Tao Lin

TL;DR

This paper tackles IO-bound bottlenecks in tree-structured LLM inference caused by shared prefixes. It introduces DeFT-Flatten, a hardware-efficient attention algorithm built on KV-Guided Grouping and Flattened Tree KV Splitting. By loading the KV cache of shared prefixes only once and balancing KV partitions across GPU streaming multiprocessors (SMs), DeFT sharply reduces KV IO and speeds up both end-to-end decoding and attention on few-shot prompting, multi-step reasoning, and speculative decoding tasks. The method is implemented in OpenAI Triton as a two-phase attention kernel and is backed by a system framework for tree KV cache management, yielding up to 2.23x end-to-end and 3.59x attention speedups. Ablations show that balanced partitioning and longer prompts amplify the gains, suggesting strong practical impact for scalable tree-based LLM inference.
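Balancing KV partitions means one query's KV cache can be split across several partitions, each producing only a partial attention result that must later be combined. The sketch below shows one way to do that combination, assuming the standard log-sum-exp rescaling used by split-KV approaches such as Flash-Decoding; DeFT's own Triton reduction may differ in detail, and the function names are hypothetical.

```python
import math
import random

def partial_attention(q, keys, values):
    """Softmax attention of one query over one KV partition.

    Returns the partition's locally normalized output and the log-sum-exp
    of its scaled scores; the reduction step needs nothing else.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)                      # subtract the max for numerical stability
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    out = [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]
    return out, m + math.log(z)

def merge_partials(partials):
    """Combine (output, lse) pairs computed on disjoint KV partitions of one query."""
    g = max(lse for _, lse in partials)
    weights = [math.exp(lse - g) for _, lse in partials]  # each partition's share of softmax mass
    total = sum(weights)
    dim = len(partials[0][0])
    return [sum(w * out[j] for w, (out, _) in zip(weights, partials)) / total
            for j in range(dim)]

# Demo: splitting 30 KV tokens into uneven partitions reproduces single-pass attention.
rng = random.Random(0)
dim, n_tokens = 8, 30
q = [rng.gauss(0, 1) for _ in range(dim)]
keys = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_tokens)]
values = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_tokens)]

full, _ = partial_attention(q, keys, values)
merged = merge_partials([partial_attention(q, keys[a:b], values[a:b])
                         for a, b in [(0, 10), (10, 25), (25, 30)]])
max_err = max(abs(x - y) for x, y in zip(full, merged))
```

Each partition only has to emit its output vector plus one scalar, which is why partial results can be merged cheaply regardless of how the KV cache was split.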

Abstract

Large language models (LLMs) are increasingly employed for complex tasks that process multiple generation calls in a tree structure with shared token prefixes, including few-shot prompting, multi-step reasoning, and speculative decoding. However, existing inference systems for tree-based applications are inefficient because they partition queries and the KV cache improperly during attention calculation. This causes two main issues: (1) no reuse of memory access (IO) for the KV cache of shared prefixes, and (2) poor load balancing. As a result, KV cache IO between GPU global memory and shared memory is redundant, and GPU utilization is low. To address these challenges, we propose DeFT (Decoding with Flash Tree-Attention), a hardware-efficient attention algorithm with prefix-aware and load-balanced KV cache partitions. DeFT reduces the number of KV cache read/write operations during attention calculation through KV-Guided Grouping, a method that avoids repeatedly loading the KV cache of shared prefixes. Additionally, we propose Flattened Tree KV Splitting, a mechanism that distributes the KV cache evenly across partitions with little redundant computation, improving GPU utilization during attention. By reducing KV cache IO by 73-99% and IO for partial results by nearly 100% during attention calculation, DeFT achieves up to 2.23x end-to-end and 3.59x attention-latency speedups across three practical tree-based workloads compared to state-of-the-art attention algorithms. Our code is available at https://github.com/LINs-lab/DeFT.
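To make the IO saving of KV-Guided Grouping concrete, the sketch below counts KV tokens read from global memory for a toy decoding tree under two groupings: Q-Guided Grouping (one group per query, so each query reloads its whole root path) and KV-Guided Grouping (one group per tree node, so each node's KV is loaded once). It is a simplified cost model under our own assumptions: the tree and names are hypothetical, and it ignores the comparatively small extra query IO that KV-Guided Grouping pays when a query appears in several groups.

```python
# Hypothetical toy decoding tree: node -> (parent, number of KV tokens).
TREE = {
    "prefix": (None, 4000),  # shared prompt, e.g. few-shot examples
    "a": ("prefix", 50),
    "b": ("prefix", 50),
    "a1": ("a", 20),
    "a2": ("a", 20),
    "b1": ("b", 20),
}
LEAVES = ["a1", "a2", "b1"]  # one decoding query at the tip of each branch

def path_to_root(node):
    """Yield the node and all its ancestors: the KV this branch's query attends to."""
    while node is not None:
        yield node
        node = TREE[node][0]

def q_guided_kv_io(leaves):
    """One group per query: every query reloads all KV tokens on its root path."""
    return sum(TREE[n][1] for leaf in leaves for n in path_to_root(leaf))

def kv_guided_kv_io(leaves):
    """One group per node: each node's KV is loaded once and shared by its queries."""
    used = {n for leaf in leaves for n in path_to_root(leaf)}
    return sum(TREE[n][1] for n in used)

q_io = q_guided_kv_io(LEAVES)    # 3 * (4000 + 50 + 20) = 12210 tokens
kv_io = kv_guided_kv_io(LEAVES)  # 4000 + 2 * 50 + 3 * 20 = 4160 tokens
print(f"KV IO reduction in this toy tree: {1 - kv_io / q_io:.1%}")
```

The saving grows with the length of the shared prefix and the number of branches, because Q-Guided Grouping pays for the prefix once per query while KV-Guided Grouping pays for it once in total.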

Paper Structure (46 sections, 18 figures, 21 tables, 4 algorithms)

Figures (18)

  • Figure 1: An illustration of Sequence-based decoding and Tree-based decoding.
  • Figure 2: Overview of DeFT. Input metadata is prepared by the system described in the appendix. In the QKV Preparation Phase (see Section 3), Q, K, and V are logically grouped into partitions that are IO-aware of the shared prefixes' KV cache and load-balanced. These partitions guide how QKV is loaded in the Attention Calculation Phase (see Appendix A), where attention is computed.
  • Figure 3: Comparison of QKV partitioning strategies in the QKV Preparation Phase for DeFT-Node, DeFT-Node-Chunk, and DeFT-Flatten versus different attention-algorithm baselines. The partitioning is purely logical and incurs no data-movement cost for QKV. The IO between GPU HBM and shared memory required by each group is highlighted with red rectangles. Part (a) illustrates the dataflow of a two-cascaded decoding tree example and three categories of QKV partitioning strategies: no partitioning (Vanilla Tree Attention), Q-Guided Grouping, and KV-Guided Grouping. The partitioning strategy guides the loading of QKV in the subsequent Attention Calculation Phase, where each QKV group $G_i$ is loaded into $SM_i$ on the GPU. Part (b) compares Q-Guided Grouping with KV-Guided Grouping; the latter is IO-aware of the prefix KV cache $KV_0$ and loads it only once. DeFT-Node-Chunk is a limited load-balancing improvement over DeFT-Node that splits large nodes (e.g., $KV_0$) into chunks. Part (c) details Flattened Tree KV Splitting in DeFT-Flatten (discussed in Remark 3.1), which produces load-balanced partitions through a depth-first flatten strategy, an evenly block-wise strategy, and a bit mask. For a summary of the baselines and DeFT, see the grouping summary table. For an analysis of the tree-attention baselines Medusa (Cai et al., 2024) and SpecInfer (Miao et al., 2023), see Remark 3.2.
  • Figure 4: Latency breakdown for speculative decoding with a token tree of 32 queries, whose tree topology is taken from Medusa (Cai et al., 2024). "U" denotes unpaged memory.
  • Figure 5: Illustration of DeFT. (Left) System overview. (Right) The data flow of DeFT-Node (DeFT-Flatten is similar except for QKV partitioning) using a decoding tree example.
  • ...and 13 more figures
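Flattened Tree KV Splitting, as listed in the Figure 3 caption, combines a depth-first flatten, an evenly block-wise split, and a bit mask. The following is a minimal illustration under our own assumptions (toy tree, pure Python, one bit per query per KV token); it is not DeFT's Triton implementation, and every name in it is hypothetical. Although node sizes differ (10, 3, 3, and 2 tokens), every block receives the same number of KV tokens, and the bit mask records which queries may attend to each token.

```python
# Hypothetical toy tree: node -> (parent, number of KV tokens).
TREE = {"prefix": (None, 10), "a": ("prefix", 3), "b": ("prefix", 3), "a1": ("a", 2)}
QUERIES = ["a1", "b"]  # hypothetical: the node at which each decoding query sits

def children(node):
    return [n for n, (parent, _) in TREE.items() if parent == node]

def ancestors(node):
    """The node itself plus every ancestor: the KV a query at `node` may attend to."""
    seen = set()
    while node is not None:
        seen.add(node)
        node = TREE[node][0]
    return seen

def depth_first_flatten(root):
    """Lay out KV tokens node by node in depth-first order as (node, offset) pairs."""
    tokens, stack = [], [root]
    while stack:
        node = stack.pop()
        tokens.extend((node, i) for i in range(TREE[node][1]))
        stack.extend(reversed(children(node)))  # visit children left to right
    return tokens

def split_evenly(tokens, block_size):
    """Cut the flat sequence into equal-size blocks, ignoring node boundaries."""
    return [tokens[i:i + block_size] for i in range(0, len(tokens), block_size)]

def block_bitmask(block, queries):
    """Per token, an int whose bit j is set if query j may attend to that token."""
    visible = [ancestors(q) for q in queries]
    return [sum(1 << j for j, vis in enumerate(visible) if node in vis)
            for node, _ in block]

blocks = split_evenly(depth_first_flatten("prefix"), block_size=6)
masks = [block_bitmask(b, QUERIES) for b in blocks]
```

Depth-first order keeps every subtree contiguous in the flat sequence, so a block spanning a node boundary still touches only a few neighboring branches; the bit mask is what lets a block mix tokens from different nodes without breaking the tree's causal structure.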

Theorems & Definitions (7)

  • Remark 3.1: Techniques of Flattened Tree KV Splitting
  • Remark 3.2: Discussion on Tree Attention Algorithms
  • Remark A.1: Importance of tiling and fused kernel during Attention Calculation Phase
  • Remark A.2: The effects of introducing a causal mask
  • Remark A.3: KV IO in SpecInfer
  • Remark A.4: IO in Radix Attention
  • Remark A.5: Causal mask IO
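Remarks A.2 and A.5 concern the cost of the causal mask. As a back-of-the-envelope illustration only, assuming the mask is read from global memory and comparing a dense fp16 mask against one bit per (query, KV token) pair, the sketch below counts mask bytes. The sizes are hypothetical (32 queries, echoing the token tree of Figure 4, over a 4000-token KV cache), the function names are ours, and this does not reproduce the remarks' own analysis.

```python
def dense_mask_bytes(n_queries, n_kv_tokens, bytes_per_entry=2):
    """Dense mask: one element (fp16 by default) per (query, KV token) pair."""
    return n_queries * n_kv_tokens * bytes_per_entry

def packed_mask_bytes(n_queries, n_kv_tokens):
    """Bit mask: one bit per pair, padded to whole bytes for each KV token."""
    return n_kv_tokens * ((n_queries + 7) // 8)

def pack_row(visible):
    """Pack one KV token's per-query visibility flags into an integer bit field."""
    bits = 0
    for j, flag in enumerate(visible):
        if flag:
            bits |= 1 << j
    return bits

# Hypothetical sizes: a 32-query token tree over a 4000-token KV cache.
dense = dense_mask_bytes(32, 4000)    # 32 * 4000 * 2 = 256000 bytes
packed = packed_mask_bytes(32, 4000)  # 4000 * 4 = 16000 bytes
```

Under these assumptions the packed mask is 16x smaller, and with up to 32 queries each KV token's visibility fits in a single 32-bit word.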