Acceleration Multiple Heads Decoding for LLM via Dynamic Tree Attention
Zhendong Zhang
TL;DR
Problem: latency in LLM inference is bottlenecked by memory bandwidth and the sequential nature of autoregressive decoding. Approach: replace the fixed MEDUSA tree with dynamic tree attention and a lightweight candidate-generation pipeline. Candidates are drawn from the Cartesian product of the per-head token distributions across $K$ heads (with $K=4$, $n=64$, $m=32$), scored by $P(i_1,\ldots,i_k) = \prod_{j=1}^{k} p_{i_j}^{(j)}$, and the top-scoring ones are kept; the dynamic-tree buffers are then constructed with $O(Kn)$ complexity. Findings: experiments with Vicuna-7B on MT-Bench show improved decoding efficiency with maintained generation quality, although wall-clock time is currently about 10% slower because of overhead in the current implementation; code is publicly available. Significance: suggests that context-aware dynamic tree structures can improve multi-head decoding in LLMs with manageable complexity, offering a practical route to lower inference latency once that overhead is reduced.
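The candidate-scoring rule $P(i_1,\ldots,i_k) = \prod_{j=1}^{k} p_{i_j}^{(j)}$ can be illustrated with a minimal Python sketch. This is not the author's implementation: the function name `top_candidates`, the best-first heap search, and the reading of $n$ as the number of candidates kept and $m$ as the number of tokens considered per head are all assumptions made for illustration. The sketch also assumes that candidates may be prefixes of any length $k \le K$.

```python
import heapq
from typing import List, Sequence, Tuple

Candidate = Tuple[Tuple[int, ...], float]


def top_candidates(head_probs: Sequence[Sequence[float]], n: int, m: int) -> List[Candidate]:
    """Return up to n token sequences with the highest product-of-marginals score.

    head_probs[j][t] is the probability that head j assigns to token t
    (hypothetical input layout). Only the m most likely tokens of each
    head are considered, so the result is approximate when m is small.
    """
    # Keep the m most likely (token, probability) pairs of every head.
    top_tokens = [
        heapq.nlargest(m, enumerate(probs), key=lambda tp: tp[1])
        for probs in head_probs
    ]

    # Max-heap via negated scores; start from length-1 candidates.
    heap = [(-p, (tok,)) for tok, p in top_tokens[0]]
    heapq.heapify(heap)

    selected: List[Candidate] = []
    while heap and len(selected) < n:
        neg_score, seq = heapq.heappop(heap)
        selected.append((seq, -neg_score))
        depth = len(seq)
        # Extend the chosen prefix with the next head's tokens. A child
        # never scores higher than its parent, so popping in score order
        # always selects a prefix before any of its extensions.
        if depth < len(head_probs):
            for tok, p in top_tokens[depth]:
                heapq.heappush(heap, (neg_score * p, seq + (tok,)))
    return selected


if __name__ == "__main__":
    probs = [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1]]  # toy example with K = 2 heads
    for seq, score in top_candidates(probs, n=4, m=2):
        print(seq, round(score, 4))
```

Because a prefix is always selected before its extensions, the returned set is prefix-closed and can be read directly as a tree, which is the property a tree attention mask requires.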
Abstract
Multiple heads decoding accelerates the inference of Large Language Models (LLMs) by predicting the next several tokens simultaneously. It generates and verifies multiple candidate sequences in parallel via tree attention with a fixed structure. In this paper, we replace the fixed tree attention with dynamic tree attention for multiple heads decoding, specifically in the context of MEDUSA. We propose a simple, low-complexity strategy to generate candidates and construct the dynamic tree structure. Preliminary experiments show that the proposed method improves the decoding efficiency of multiple heads decoding for LLMs while maintaining generation quality. This result demonstrates that candidate generation in multiple heads decoding still has room for improvement.
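To make the idea of constructing a tree structure from candidates concrete, the following Python sketch builds parent indices, depths, and an ancestor-only attention mask from a prefix-closed set of candidate sequences. It is a sketch under stated assumptions rather than the paper's buffer-construction code: the function name `build_tree_buffers`, the plain-list output layout, and the convention that node 0 is the root token are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

Seq = Tuple[int, ...]


def build_tree_buffers(candidates: Sequence[Seq]):
    """Build tree-attention buffers for a prefix-closed candidate set.

    Node 0 is the root (assumed here to be the token already produced by
    the base model); every candidate sequence becomes one further node.
    Returns (nodes, parent, depth, mask), where mask[i][j] == 1 means
    node i may attend to node j (itself or one of its ancestors).
    """
    # Sort by length so each parent is processed before its children.
    nodes: List[Seq] = sorted(set(candidates), key=lambda s: (len(s), s))
    index: Dict[Seq, int] = {(): 0}
    for i, seq in enumerate(nodes, start=1):
        index[seq] = i

    size = len(nodes) + 1
    parent = [-1] * size
    depth = [0] * size
    mask = [[0] * size for _ in range(size)]
    mask[0][0] = 1  # the root attends only to itself within the tree

    for seq in nodes:
        i = index[seq]
        prefix = seq[:-1]
        if prefix not in index:
            raise ValueError(f"candidate set is not prefix-closed: {seq}")
        p = index[prefix]
        parent[i] = p
        depth[i] = len(seq)       # can serve as a relative position id
        mask[i] = list(mask[p])   # inherit visibility of all ancestors
        mask[i][i] = 1            # plus the node itself
    return nodes, parent, depth, mask


if __name__ == "__main__":
    cands = [(0,), (0, 0), (0, 1), (1,)]
    nodes, parent, depth, mask = build_tree_buffers(cands)
    print(parent)
    for row in mask:
        print(row)
```

Each mask row is copied from its parent's row, so the work per node is linear in the number of nodes; a tensor-based implementation would typically store the same information in preallocated buffers instead of Python lists.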
