
Recursive Speculative Decoding: Accelerating LLM Inference via Sampling Without Replacement

Wonseok Jeon, Mukul Gagrani, Raghavv Goel, Junyoung Park, Mingu Lee, Christopher Lott

TL;DR

Recursive Speculative Decoding (RSD) is a novel tree-based method that samples draft tokens without replacement to maximize the diversity of the draft-token tree. It outperforms the baseline methods consistently for a fixed draft sequence length and, in most cases, for a fixed computational budget at the target LLM.

Abstract

Speculative decoding is an inference-acceleration method for large language models (LLMs) in which a small language model generates a draft-token sequence that the target LLM then verifies in parallel. Recent works have advanced this method by building a draft-token tree, achieving superior performance over single-sequence speculative decoding. However, those works generate tokens independently at each level of the tree and therefore do not fully exploit the diversity the tree can offer. Moreover, their empirical superiority has been shown only for fixed sequence lengths, which implicitly grants the tree-based methods more computational resources at the target LLM. None of the existing works has conducted empirical studies with fixed target computational budgets, despite their importance for resource-bounded devices. We present Recursive Speculative Decoding (RSD), a novel tree-based method that samples draft tokens without replacement and maximizes the diversity of the tree. During RSD's drafting, the tree is built either by the Gumbel-Top-$k$ trick, which draws tokens without replacement in parallel, or by Stochastic Beam Search, which samples sequences without replacement while early-truncating unlikely draft sequences and thereby reducing the computational cost at the target LLM. We empirically evaluate RSD with Llama 2 and OPT models, showing that RSD outperforms the baseline methods consistently for fixed draft sequence length and, in most cases, for fixed computational budgets at the target LLM.
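The Gumbel-Top-$k$ trick mentioned above can be sketched in a few lines of Python. This is an illustrative sketch, not the authors' implementation; the function name `gumbel_top_k` and its signature are our own. Adding independent Gumbel(0, 1) noise to every log-probability and keeping the $k$ largest perturbed values yields $k$ distinct tokens whose order is distributed like $k$ successive draws without replacement, and all $k$ draws come from a single parallel pass over the vocabulary.

```python
import math
import random


def gumbel_top_k(log_probs, k, rng=random):
    """Draw k distinct indices without replacement from softmax(log_probs).

    Hypothetical helper (not from the paper): perturb each log-probability
    with independent Gumbel(0, 1) noise and keep the k largest values.
    The returned order matches k successive draws without replacement.
    """
    if not 0 < k <= len(log_probs):
        raise ValueError("k must be in [1, len(log_probs)]")
    perturbed = []
    for idx, lp in enumerate(log_probs):
        u = rng.random()
        while u == 0.0:  # random() is in [0, 1); skip 0 to avoid log(0)
            u = rng.random()
        gumbel = -math.log(-math.log(u))  # Gumbel(0, 1) sample
        perturbed.append((lp + gumbel, idx))
    perturbed.sort(reverse=True)  # largest perturbed log-prob first
    return [idx for _, idx in perturbed[:k]]


if __name__ == "__main__":
    p = [0.5, 0.3, 0.15, 0.05]  # toy draft distribution
    print(gumbel_top_k([math.log(x) for x in p], 2, random.Random(0)))
```

Because every token gets its own noise sample, the $k$ draws need no sequential renormalization, which is what makes the trick attractive for building each tree level in parallel.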

Paper Structure (30 sections, 4 theorems, 28 equations, 9 figures, 54 tables, 9 algorithms)

Key Result

Theorem 3.1

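For the random variable $Z\in\mathcal{X}$ defined by the recursive rejection sampling procedure (eq:recursive_rejection_sampling), $Z$ follows the target distribution $q$, i.e., $\Pr(Z=z)=q(z)$ for all $z\in\mathcal{X}$.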
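The statement of Theorem 3.1 can be checked numerically with a small simulation. The sketch below is our reading of recursive rejection sampling, not the paper's code: we assume the $K$ draft tokens are drawn from the draft distribution $p$ without replacement, that draft token $X_i$ is accepted with probability $\min(1, q_i(X_i)/p_i(X_i))$, and that after a rejection the target is replaced by the residual $q_{i+1}\propto\max(q_i-p_i,0)$ familiar from single-sequence speculative sampling, while $p_{i+1}$ is $p_i$ with $X_i$ removed and renormalized. The names `recursive_rejection_sampling`, `_draw` and `_normalize` are hypothetical.

```python
import random


def _normalize(weights):
    total = sum(weights)
    return [w / total for w in weights]


def _draw(dist, rng):
    """Categorical sample by inverse CDF; never returns a zero-probability index."""
    u, acc = rng.random(), 0.0
    for idx, pr in enumerate(dist):
        acc += pr
        if u < acc:
            return idx
    return max(idx for idx, pr in enumerate(dist) if pr > 0)  # round-off guard


def recursive_rejection_sampling(p, q, k, rng=random):
    """Return one token whose law should equal q, using k drafts from p (no replacement).

    Assumed rules (our reading, hypothetical names):
      accept X_i with prob min(1, q_i(X_i) / p_i(X_i));
      on rejection q_{i+1} ∝ max(q_i - p_i, 0) and p_{i+1} = p_i without X_i.
    """
    p_i, q_i = list(p), list(q)
    for _ in range(k):
        x = _draw(p_i, rng)  # next draft token, excluding earlier drafts
        if rng.random() < min(1.0, q_i[x] / p_i[x]):
            return x  # draft token accepted
        # Rejected: residual target puts zero mass on x, since q_i[x] < p_i[x]
        q_i = _normalize([max(a - b, 0.0) for a, b in zip(q_i, p_i)])
        remaining = [0.0 if j == x else pj for j, pj in enumerate(p_i)]
        if sum(remaining) <= 0.0:
            break  # draft support exhausted
        p_i = _normalize(remaining)
    return _draw(q_i, rng)  # all drafts rejected: sample the final residual


if __name__ == "__main__":
    rng = random.Random(1)
    p, q = [0.6, 0.3, 0.1], [0.2, 0.3, 0.5]
    counts = [0, 0, 0]
    for _ in range(30000):
        counts[recursive_rejection_sampling(p, q, 2, rng)] += 1
    print([c / 30000 for c in counts])  # should be close to q
```

For $K=1$ this reduces to ordinary speculative sampling; with larger $K$, each rejected draft removes itself from both the draft and the residual target, so later drafts are always new tokens.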

Figures (9)

  • Figure 1: Acceptance rates for multi-round speculative decoding, K-SEQ, OTM and recursive rejection sampling are given when $\mathrm{Ber}(p)$ and $\mathrm{Ber}(q)$ are draft and target distributions, respectively, and two tokens are proposed by the draft model ($K=2$).
  • Figure 2: We describe the entire process of RSD with Stochastic Beam Search (RSD-S); RSD-S and RSD with Constant branching factors (RSD-C) differ only in how the draft-token tree is constructed. Draft tokens in the tree are sampled in parallel at each level and auto-regressively across levels, while Stochastic Beam Search samples sequences without replacement at each tree level. The resulting draft-token tree is then processed by the target model in parallel, which yields the token-wise target-model probabilities. Finally, recursive rejection sampling (for the sampling-without-replacement distribution) is applied to each level of the tree, recovering the sequence-generation distribution of the target model.
  • Figure 3: We describe examples of constructing draft-token trees with the (maximum) draft length equal to 3; (a) The tree constructed by RSD-C with branching factors $\mathbf{b}=(3, 2, 1)$ is given; (b) we depict the tree constructed by RSD-S with beamwidth $W=3$, where edges are determined via Stochastic Beam Search.
  • Figure 4: Block efficiency, MBSU, token rate and accuracy for various lengths ($2,3,4,5$) of draft sequences are given. We consider two target models, Llama 2-70B and Llama 2-Chat-70B, each of which has a corresponding smaller draft model for speculative decoding. All results are normalized by the corresponding numbers from auto-regressive decoding. RSD-S always outperforms SD, SpecTr and RSD-C. All methods including auto-regressive decoding show similar accuracy for WMT and XSum.
  • Figure 5: Block efficiency, MBSU, token rate and accuracy for various target computational budgets (the numbers $6, 10, 14, 21, 30$ of draft tokens processed at the target model) are given. We consider two target models, Llama 2-70B and Llama 2-Chat-70B, each of which has a corresponding smaller draft model for speculative decoding. All results are normalized by the corresponding numbers from auto-regressive decoding. RSD-S outperforms SD, SpecTr and RSD-C in the majority of cases. All methods including auto-regressive decoding show similar accuracy for both WMT and XSum.
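The RSD-C tree in Figure 3(a) can be sketched as follows, together with the arithmetic behind its size. Assuming each tree node is one draft token the target model must process, branching factors $\mathbf{b}=(3,2,1)$ give $3$, $3\cdot2=6$ and $3\cdot2\cdot1=6$ nodes per level, i.e. $15$ draft tokens, whereas a beam of width $W=3$ over 3 levels keeps at most $3\cdot3=9$. The toy draft model `toy_draft_log_probs`, the vocabulary size and the function names are hypothetical stand-ins, not the paper's implementation; children of each node are drawn without replacement with the Gumbel-Top-$k$ trick.

```python
import math
import random

VOCAB_SIZE = 5  # hypothetical toy vocabulary size


def toy_draft_log_probs(prefix):
    """Hypothetical stand-in for the draft model: next-token log-probs given a prefix."""
    center = sum(prefix) % VOCAB_SIZE
    logits = [-abs(v - center) for v in range(VOCAB_SIZE)]
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def gumbel_top_k(log_probs, k, rng):
    """Sample k distinct tokens without replacement (Gumbel-Top-k trick)."""
    keys = []
    for tok, lp in enumerate(log_probs):
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        keys.append((lp - math.log(-math.log(u)), tok))  # log-prob + Gumbel noise
    keys.sort(reverse=True)
    return [tok for _, tok in keys[:k]]


def build_constant_branching_tree(branching, rng, root_prefix=()):
    """Expand every node at depth d into branching[d] distinct children.

    Returns one list per level; each node is (parent_index, token, prefix).
    """
    levels = []
    frontier = [tuple(root_prefix)]
    for b in branching:
        level = []
        for parent_idx, prefix in enumerate(frontier):
            for tok in gumbel_top_k(toy_draft_log_probs(prefix), b, rng):
                level.append((parent_idx, tok, prefix + (tok,)))
        levels.append(level)
        frontier = [node_prefix for _, _, node_prefix in level]
    return levels


def draft_token_count(branching):
    """Nodes in a constant-branching tree: sum over depths d of b_1 * ... * b_d."""
    return sum(math.prod(branching[: d + 1]) for d in range(len(branching)))


if __name__ == "__main__":
    levels = build_constant_branching_tree((3, 2, 1), random.Random(0))
    print([len(level) for level in levels])  # [3, 6, 6]
    print(draft_token_count((3, 2, 1)))  # 15
```

Under this counting, a constant-branching tree grows multiplicatively with depth, which is one way to see why comparisons at a fixed draft length can hand tree-based methods a larger target-side budget than comparisons at a fixed number of processed tokens.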

Theorems & Definitions (8)

  • Theorem 3.1: Recursive rejection sampling recovers target distribution
  • proof
  • Theorem 3.2: Tokens from the same sequence follow sampling without replacement in RSD-S
  • proof
  • Theorem 3.1: Recursive rejection sampling recovers target distribution
  • proof
  • Theorem 3.2: Tokens from the same sequence follow sampling without replacement in RSD-S
  • proof
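To make the RSD-S drafting step behind Theorem 3.2 more concrete, here is a minimal sketch of Stochastic Beam Search over a toy draft model. It is not the authors' implementation: `toy_draft_log_probs`, `VOCAB_SIZE` and `stochastic_beam_search` are hypothetical, and we use the plain (not numerically hardened) form of the conditional Gumbel update. Each beam entry carries its log-probability $\phi$ and a perturbed value $G$; children receive Gumbel-perturbed log-probabilities that are shifted so their maximum equals the parent's $G$, and the top-$W$ children across all beams form the next tree level.

```python
import math
import random

VOCAB_SIZE = 4  # hypothetical toy vocabulary size


def toy_draft_log_probs(prefix):
    """Hypothetical draft model: next-token log-probs that depend on the prefix."""
    logits = [math.sin(1.3 * v + 0.7 * len(prefix) + sum(prefix)) for v in range(VOCAB_SIZE)]
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def _gumbel(loc, rng):
    """Sample Gumbel(loc, 1)."""
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return loc - math.log(-math.log(u))


def stochastic_beam_search(beam_width, depth, rng):
    """Return the kept prefixes at each level (one tree level per step)."""
    beam = [((), 0.0, 0.0)]  # (prefix, log-prob phi, perturbed log-prob G); root G = 0
    levels = []
    for _ in range(depth):
        candidates = []
        for prefix, phi, g_parent in beam:
            child_phi = [phi + lp for lp in toy_draft_log_probs(prefix)]
            child_g = [_gumbel(c, rng) for c in child_phi]
            z = max(child_g)
            for tok, (c_phi, g) in enumerate(zip(child_phi, child_g)):
                # Conditional Gumbel: shift children so their max equals g_parent
                # (plain form; a production version would use a stable rearrangement)
                g_tilde = -math.log(math.exp(-g_parent) - math.exp(-z) + math.exp(-g))
                candidates.append((g_tilde, prefix + (tok,), c_phi))
        candidates.sort(key=lambda c: c[0], reverse=True)
        beam = [(pre, c_phi, g_t) for g_t, pre, c_phi in candidates[:beam_width]]
        levels.append([pre for pre, _, _ in beam])
    return levels


if __name__ == "__main__":
    for level in stochastic_beam_search(3, 3, random.Random(0)):
        print(level)
```

Because each level keeps only the top-$W$ perturbed prefixes, unlikely branches are truncated early and every level holds at most $W$ draft tokens, in line with the abstract's point about reducing the computational cost at the target LLM.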