A Theoretical Perspective for Speculative Decoding Algorithm

Ming Yin, Minshuo Chen, Kaixuan Huang, Mengdi Wang

TL;DR

The results reveal the fundamental connections between different components of LLMs via total variation distances and show how they jointly affect the efficiency of decoding algorithms.

Abstract

Transformer-based autoregressive sampling is a major bottleneck in large language model inference. One effective way to accelerate inference is Speculative Decoding, which employs a small model to sample a sequence of draft tokens and a large model to validate them. Despite its empirical effectiveness, the theoretical understanding of Speculative Decoding lags behind. This paper addresses this gap by conceptualizing the decoding problem via a Markov chain abstraction and studying its two key properties, output quality and inference acceleration, from a theoretical perspective. Our analysis covers the theoretical limits of speculative decoding, batch algorithms, and the tradeoff between output quality and inference acceleration. Our results reveal the fundamental connections between different components of LLMs via total variation distances and show how they jointly affect the efficiency of decoding algorithms.

Paper Structure

This paper contains 26 sections, 14 theorems, 88 equations, 5 figures, 2 tables, 4 algorithms.

Key Result

Theorem 1

We have the following two results for Speculative Decoding. (1) Define random variables $R_n\in\{0,1\}$ that indicate whether the $n$-th token is rejected (with $1$ meaning rejected). Here rejection means that Line 7 of Algorithm \ref{alg:speculative_decoding} is executed. Then the total number of rejections … (2) The output distributions of Algorithm \ref{alg:speculative_decoding} and the target model $q$ are ide…
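Both claims can be checked numerically for a single token. The sketch below is not the paper's algorithm, which is not reproduced in this excerpt. It assumes the standard speculative sampling rule: draft $x\sim p$, accept with probability $\min\{1, q(x)/p(x)\}$, otherwise resample from the normalized residual $\max\{q-p,0\}$. Under that rule the single-token rejection probability equals ${\sf TV}[p,q]$ and the output follows $q$. The distributions `p`, `q` and the helper names are illustrative.

```python
import numpy as np

def tv_distance(p, q):
    """Total variation distance between two discrete distributions."""
    return 0.5 * np.abs(p - q).sum()

def speculative_step(p, q, rng):
    """One draft-then-verify step (assumed standard rule).

    Returns (token, rejected).
    """
    x = rng.choice(len(p), p=p)                # draft token from the small model p
    if rng.random() < min(1.0, q[x] / p[x]):   # verify against the large model q
        return x, False
    residual = np.maximum(q - p, 0.0)          # resample from normalized (q - p)_+
    residual = residual / residual.sum()
    return rng.choice(len(q), p=residual), True

def simulate(p, q, n, seed=0):
    """Empirical rejection rate and output distribution over n steps."""
    rng = np.random.default_rng(seed)
    counts = np.zeros(len(q))
    rejections = 0
    for _ in range(n):
        token, rejected = speculative_step(p, q, rng)
        counts[token] += 1
        rejections += rejected
    return rejections / n, counts / n

# Illustrative distributions (assumptions, not taken from the paper).
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.3, 0.5])

rej_rate, empirical = simulate(p, q, n=50_000)
print("TV(p, q)           :", tv_distance(p, q))
print("empirical rejection:", rej_rate)
print("empirical output   :", empirical)
```

The empirical rejection rate should be close to `tv_distance(p, q)` and the empirical output distribution close to `q`, up to Monte Carlo error.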

Figures (5)

  • Figure 1: Left: Standard Auto-Regressive Decoding (Algorithm \ref{alg:auto_sampling}) vs. Right: Speculative Decoding (Algorithm \ref{alg:speculative_decoding}), where a large model is used to validate the responses of the small model.
  • Figure 2: The numeric instance in this figure chooses $p,q$ to be nonstationary Markov chains with horizon $T=50$. Left $(a)$: A simulation of Speculative Decoding. The green line is the empirical average number of rejections among $100 N$ runs and the orange line is the theoretical value computed via Theorem \ref{thm:exp_rej}. Middle $(b)$: Batch Speculative Decoding simulations with batch sizes $M=4,5$. The green/purple lines are the empirical average numbers of rejections among $100 N$ runs and the orange/pink lines are the theoretical values computed via Theorem \ref{thm:exp_rej_batch}. Right $(c)$: The scaling law of expected rejections for Batch SD as a function of $M$. It converges to a limit as $M\rightarrow \infty$.
  • Figure 3: Left: Batch Speculative Decoding. Right: Batch Improvement vs. Batch size $M$. Upper: Bernoulli distributions with $q=\text{Ber}(0.5)$. Lower: $p\sim \text{Unif}(V)$, $q\sim \text{Unif}(V')$ with $r=V/V'$.
  • Figure 4: Left $(a)$: The Pareto Front between Rejection Probability $\mathbb{P}^{\mathcal{A}}(\text{reject})$ vs. Distribution bias ${\sf TV}[\mathbb{P}^{\mathcal{A}},q]$. For a given rejection probability, the black line denotes the optimal deviation $\text{Loss}_{\sf TV}^*$. Middle $(b)$ and Right $(c)$: A numeric example. In the plot, the over-acceptance $\epsilon$'s are set as positive constants that define $b(x)=\min\{1,\frac{q(x)+\epsilon}{p(x)}\}$.
  • Figure 5: A simulation of (Batch) Speculative Decoding with horizon $T=100$.
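
The over-acceptance rule in the Figure 4 caption, $b(x)=\min\{1,\frac{q(x)+\epsilon}{p(x)}\}$, can be illustrated with a small calculation. The sketch below assumes a single draft token $x\sim p$ accepted with probability $b(x)$, so the rejection probability is $\sum_x p(x)\,(1-b(x))$. It does not compute the distribution bias, because that depends on the resampling distribution, which this excerpt does not specify. The function name and the distributions are illustrative.

```python
import numpy as np

def rejection_probability(p, q, eps=0.0):
    """Probability that a draft x ~ p is rejected under acceptance rule b(x).

    Assumes p(x) > 0 for every x. With eps = 0 this reduces to TV(p, q).
    """
    b = np.minimum(1.0, (q + eps) / p)   # acceptance probability per token
    return float(np.sum(p * (1.0 - b)))

# Illustrative distributions (assumptions, not taken from the paper).
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.3, 0.5])

for eps in (0.0, 0.1, 0.3):
    print(f"eps={eps:.1f}  P(reject)={rejection_probability(p, q, eps):.3f}")
```

Larger $\epsilon$ accepts more drafts and so lowers the rejection probability; the cost is a possible deviation of the output from $q$, which is the tradeoff the Pareto front in Figure 4 describes.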

Theorems & Definitions (31)

  • Definition 1
  • Remark 1
  • Theorem 1
  • Remark 2
  • Theorem 2: Instance-dependent Rejection Lower Bound
  • Theorem 3: Unbiasedness and efficiency of batch SD
  • Theorem 4: Optimal solution to optimization \ref{eqn:objj}
  • Theorem 5: Pareto front
  • Theorem 6: Restatement of the first part of Theorem \ref{thm:exp_rej}
  • proof
  • ...and 21 more