Faster Cascades via Speculative Decoding

Harikrishna Narasimhan, Wittawat Jitkrittum, Ankit Singh Rawat, Seungyeon Kim, Neha Gupta, Aditya Krishna Menon, Sanjiv Kumar

TL;DR

This paper characterizes the optimal deferral rule for speculative cascades, employs a plug-in approximation to that rule, and shows that the resulting approach yields better cost-quality trade-offs than cascading and speculative decoding baselines.

Abstract

Cascades and speculative decoding are two common approaches to improving language models' inference efficiency. Both approaches involve interleaving models of different sizes, but via fundamentally distinct mechanisms: cascades employ a deferral rule that invokes the larger model only for "hard" inputs, while speculative decoding uses speculative execution to primarily invoke the larger model in parallel verification mode. These mechanisms offer different benefits: empirically, cascades offer better cost-quality trade-offs, often even outperforming the large model, while theoretically, speculative decoding offers a guarantee of quality-neutrality. In this paper, we leverage the best of both these approaches by designing new speculative cascading techniques that implement their deferral rule through speculative execution. We characterize the optimal deferral rule for our speculative cascades, and employ a plug-in approximation to the optimal rule. Experiments with Gemma and T5 models on a range of language benchmarks show that our approach yields better cost-quality trade-offs than cascading and speculative decoding baselines.
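The quality-neutrality guarantee of speculative decoding comes from its verification step. As a minimal sketch, here is the standard speculative sampling check for a single draft position, with next-token distributions represented as plain Python dicts. The function name `verify_draft_token` and the dict representation are illustrative assumptions, not code from the paper.

```python
import random


def verify_draft_token(token, q, p, rng=random):
    """Standard speculative sampling check for one draft position.

    q: small (drafter) model's next-token distribution, dict token -> prob.
    p: large (verifier) model's next-token distribution at the same prefix.
    token: the draft token, assumed to have been sampled from q (so q[token] > 0).
    Returns (accepted, emitted_token).
    """
    # Accept the draft with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, p.get(token, 0.0) / q[token]):
        return True, token
    # On rejection, resample from the residual max(0, p - q), renormalized.
    residual = {v: max(0.0, pv - q.get(v, 0.0)) for v, pv in p.items()}
    r = rng.random() * sum(residual.values())
    cumulative = 0.0
    for v, w in residual.items():
        cumulative += w
        if r < cumulative:
            return False, v
    return False, max(residual, key=residual.get)  # floating-point guard
```

Accepting with probability min(1, p/q) and resampling from the residual on rejection makes the emitted token follow p exactly, which is the source of quality-neutrality. In practice the large model scores all γ draft positions in one parallel call, and decoding stops at the first rejection.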

Paper Structure (36 sections, 8 theorems, 63 equations, 12 figures, 3 tables, 6 algorithms)

Key Result

Lemma 1

The minimizer of equation (\ref{eq:seq-def-risk}) is of the form:

Figures (12)

  • Figure 1: Plots of quality as a function of the number of deferrals to the larger model divided by the total number of generated tokens for cascades constructed from T5 models (under temperature sampling with $T = 1$). The left-most point represents the small model and the right-most represents the large model. We compare token-level cascades constructed with Chow's rule (Chow) and an oracle deferral rule (Diff), and speculative decoding with block size $\gamma = 5$. With a cascade, each call to the large model yields exactly one token, whereas with speculative decoding, a single call scores $\gamma$ draft tokens in parallel. While speculative decoding matches the quality of the large model (see dashed horizontal line), the oracle deferral rule yields significantly better quality over a range of deferral rates; however, this comes at the cost of a higher number of deferrals to the large model.
  • Figure 2: Reduction in latency from different methods ($T=1, \gamma=5$) when matching the quality of the large model (cols 2--7), and the best quality each method achieves without exceeding the latency of the large model (cols 8--13). Quality is measured by BLEU for WMT and ROUGE-2 for XSum and CNNDM. See Figure \ref{fig:tradeoffs} for $T=0.5$ and §\ref{app:T5-greedy} for $T=0$.
  • Figure 3: Plots of quality vs. rejection rate for methods that interleave Gemma 2B with Gemma 27B ($\gamma=1$). We use instruction-tuned models; for MBPP we report additional results with pre-trained models. See §\ref{app:expts-token-specific} for remaining plots, comparison to (\ref{eq:sample-dep-01-plugin-v1}--\ref{eq:sample-dep-01-plugin-v2}) and results on 2B $\rightarrow$ 9B cascades.
  • Figure 4: Plots of quality vs. latency for T5 models with temperature $T=1$ and block size $\gamma=5$, covering the T5 plots not included in Figure \ref{fig:tradeoffs} in the main text. Each method interleaves T5-small with T5-large (or T5-XL). The $x$-axis tracks the latency relative to that of calling the large model on all inputs. The horizontal dotted line denotes the quality of the large model.
  • Figure 5: Plots of quality vs. latency for T5 models with greedy decoding (temperature $T=0$) and block size $\gamma=5$. Each method interleaves T5-small with T5-large (or T5-XL). The $x$-axis tracks the latency relative to that of calling the large model on all inputs. The horizontal dotted line denotes the quality of the large model.
  • ...and 7 more figures
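Figure 1 contrasts two token-level deferral rules. The sketch below illustrates a token-level cascade under stated assumptions: Chow's rule is taken to threshold the small model's top-token probability, and the oracle rule Diff is taken to compare the two models' top-token probabilities with a margin `alpha`. These forms are our reading for illustration only; the exact definitions are those of Remark 1 in the paper. All function names and the toy model callables are hypothetical, and decoding is greedy rather than the $T = 1$ sampling used in Figure 1.

```python
def chow_defer(q, threshold):
    """Chow-style rule (assumed form): defer when the small model's
    top-token probability falls below `threshold`."""
    return max(q.values()) < threshold


def diff_defer(q, p, alpha):
    """Diff-style oracle rule (assumed form): defer when the large model's
    top-token probability exceeds the small model's by more than `alpha`."""
    return max(q.values()) < max(p.values()) - alpha


def token_level_cascade(small, large, prefix, num_tokens, defer_rule):
    """Greedy (T = 0) token-level cascade.

    small, large: hypothetical callables mapping a token list to a
        next-token distribution (dict token -> probability).
    defer_rule: callable (q, tokens) -> bool.
    Returns the generated tokens and the deferral rate
    (large-model calls divided by generated tokens).
    """
    tokens = list(prefix)
    deferrals = 0
    for _ in range(num_tokens):
        q = small(tokens)
        if defer_rule(q, tokens):
            deferrals += 1
            dist = large(tokens)  # each deferral yields exactly one token
        else:
            dist = q
        tokens.append(max(dist, key=dist.get))  # greedy choice
    return tokens[len(prefix):], deferrals / num_tokens
```

The returned deferral rate corresponds to the $x$-axis of Figure 1. Passing `lambda q, t: diff_defer(q, large(t), alpha)` as the rule shows why Diff is an oracle: evaluating it needs the large model's distribution at every position, which a plain cascade would have to pay for even when it does not defer.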

Theorems & Definitions (16)

  • Lemma 1: Optimal deferral for token-level cascades (jitkrittum2024does)
  • Remark 1: Oracle deferral rules
  • Lemma 2
  • Remark 2: Exact implementation of oracle deferral rule Diff
  • Lemma 3
  • Lemma 4: Optimal deferral for speculative cascades
  • Lemma 5: Regret bound for $\hat{r}_{\rm\tt OPT}$
  • Proof
  • Proof
  • Proof
  • ...and 6 more
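The central idea in the abstract is to implement the deferral rule through speculative execution: once the large model has scored a block of draft tokens in parallel, both $q_t$ and $p_t$ are available at every position, so a rule such as Diff can be evaluated exactly (compare Remark 2). The sketch below shows one way to realize this. It assumes that a deferral decision $r \in \{0, 1\}$ selects the verification target $\pi_t = (1 - r)\,q_t + r\,p_t$ and that drafts are then verified against $\pi_t$ with the standard speculative sampling rule. The helpers `sample` and `speculative_cascade_step` and the `large_batch` interface are hypothetical. The bonus token that standard speculative decoding samples when every draft is accepted is omitted for brevity.

```python
import random


def sample(dist, rng):
    """Draw one token from an (unnormalized) dict distribution."""
    r = rng.random() * sum(dist.values())
    cumulative = 0.0
    for token, weight in dist.items():
        cumulative += weight
        if r < cumulative:
            return token
    return max(dist, key=dist.get)  # guard against floating-point round-off


def speculative_cascade_step(small, large_batch, prefix, gamma, defer_rule, rng):
    """One draft-and-verify round of a speculative cascade (illustrative sketch).

    small: callable tokens -> next-token dict (drafter).
    large_batch: callable (prefix, drafts) -> list of `gamma` next-token dicts,
        one per draft position, standing in for a single parallel scoring call.
    defer_rule: callable (q_t, p_t) -> bool, evaluated with both distributions.
    """
    # 1) Draft gamma tokens autoregressively with the small model.
    drafts, qs, ctx = [], [], list(prefix)
    for _ in range(gamma):
        q = small(ctx)
        token = sample(q, rng)
        drafts.append(token)
        qs.append(q)
        ctx.append(token)

    # 2) One large-model call scores every draft position.
    ps = large_batch(list(prefix), drafts)

    # 3) Verify each draft against the deferral-dependent target distribution.
    out = []
    for token, q, p in zip(drafts, qs, ps):
        r = 1.0 if defer_rule(q, p) else 0.0
        vocab = sorted(set(q) | set(p))  # sorted for reproducible sampling
        target = {v: (1 - r) * q.get(v, 0.0) + r * p.get(v, 0.0) for v in vocab}
        if rng.random() < min(1.0, target[token] / q[token]):
            out.append(token)  # draft accepted
            continue
        # Rejected: resample from the residual max(0, target - q) and stop.
        residual = {v: max(0.0, target[v] - q.get(v, 0.0)) for v in vocab}
        out.append(sample(residual, rng))
        break
    return out
```

With a rule that never defers, every draft is accepted and the step reduces to running the small model alone; with a rule that always defers, $\pi_t = p_t$ and the step reduces to standard speculative decoding. Intermediate rules trade off between the two, which is the space the paper's optimal and plug-in deferral rules operate in.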