
Reject Only Critical Tokens: Pivot-Aware Speculative Decoding

Amir Ziashahabi, Yavuz Faruk Bakman, Duygu Nur Yaldiz, Mostafa El-Khamy, Sai Praneeth Karimireddy, Salman Avestimehr

TL;DR

This paper tackles slow generation in large language models by questioning whether Speculative Decoding (SD) really needs to match the target model's distribution exactly. It introduces Pivot-Aware Speculative Decoding (PAD), which relaxes SD so that it preserves expected task utility instead. PAD identifies pivot tokens, whose acceptance would degrade final task performance, and uses a lightweight classifier so that only those tokens are rejected. The method combines Monte Carlo rollouts under a binary utility to label pivot tokens, an LLM-as-judge sanity check on those labels, and an inference-time pivot predictor that raises draft acceptance. Experiments on math and programming tasks show speedups of up to $2.5\times$ with utility-preserving outputs, demonstrating PAD's practical value across both domains.
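A minimal sketch of how Monte Carlo rollouts with a binary utility could label one SD-rejected draft token. The interfaces `rollout` (completes a prefix with the target model) and `utility` (returns 1 for a correct final output, 0 otherwise), the function names, and the labeling rule are assumptions for illustration, not the authors' code. Here a draft token is marked pivotal when rollouts continuing from it score lower, by more than `margin`, than rollouts from the same prefix without it. The LLM-as-judge sanity check is omitted.

```python
from typing import Callable, List, Sequence

# Assumed interfaces (hypothetical, not the paper's API):
#   rollout(prefix) -> final output text from the target model continuing `prefix`
#   utility(output) -> 1 if the final output is correct, else 0 (binary utility)
Rollout = Callable[[Sequence[str]], str]
Utility = Callable[[str], int]


def estimate_utility(prefix: Sequence[str], rollout: Rollout,
                     utility: Utility, num_rollouts: int = 8) -> float:
    """Monte Carlo estimate of the expected binary utility after `prefix`."""
    return sum(utility(rollout(prefix)) for _ in range(num_rollouts)) / num_rollouts


def label_pivot(prefix: Sequence[str], draft_token: str, rollout: Rollout,
                utility: Utility, num_rollouts: int = 8, margin: float = 0.0) -> bool:
    """Return True if accepting `draft_token` after `prefix` is estimated to lower utility.

    Assumed baseline: rollouts from `prefix` alone, i.e. the target picks the next token itself.
    """
    with_draft: List[str] = list(prefix) + [draft_token]
    u_draft = estimate_utility(with_draft, rollout, utility, num_rollouts)
    u_base = estimate_utility(prefix, rollout, utility, num_rollouts)
    return u_draft < u_base - margin


# Toy example in the spirit of Figure 1: drafting "2" derails the answer.
# The toy rollout is deterministic, so the Monte Carlo estimates are exact.
def toy_rollout(prefix: Sequence[str]) -> str:
    return "wrong" if "2" in prefix else "right"


def toy_utility(output: str) -> int:
    return 1 if output == "right" else 0


if __name__ == "__main__":
    print(label_pivot(["x", "="], "2", toy_rollout, toy_utility))  # True: pivot
    print(label_pivot(["x", "="], "1", toy_rollout, toy_utility))  # False: non-pivot
```

With a stochastic target model, `num_rollouts` trades labeling cost against the variance of each utility estimate, and a positive `margin` keeps sampling noise from producing spurious pivot labels.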

Abstract

Speculative Decoding (SD) guarantees that the output follows the target model's distribution exactly. However, we argue that this distribution-matching requirement is too stringent: it causes unnecessarily low acceptance rates and limits the achievable speedup. Instead, we advocate reformulating the decoding objective: the decoding strategy should match the expected utility, i.e., the task-specific performance, of the target model. This perspective also aligns better with real-world uses of LLMs, where utility (e.g., code correctness, factual accuracy) often matters more than reproducing a particular sampling distribution. Based on this reformulation, we propose a novel decoding strategy, Pivot-Aware Speculative Decoding (PAD), which rejects only those draft tokens that would lead to a utility drop in the final output. We refer to these critical tokens as pivot tokens. We propose a method for labeling tokens as pivotal or non-pivotal and train a lightweight classifier to detect them. PAD can be viewed as a relaxed version of standard SD that achieves much higher acceptance rates while preserving utility. We evaluate our method on several datasets and achieve up to a $2.5\times$ speedup with comparable utility. Source code is available at https://github.com/amir-zsh/PAD.
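The relaxed acceptance rule can be sketched as a small change to the SD verification loop. This is a hedged illustration that assumes per-position target and draft distributions given as token-to-probability dicts and a hypothetical `pivot_prob` scorer. `sd_accepts`, `pad_accepts`, and `pad_verify` are illustrative names, not the authors' code. A draft token passes if the usual SD test $u < \min(1, p(x)/q(x))$ accepts it, or, failing that, if the classifier's pivot probability falls below `threshold`. Resampling after a rejection is assumed to proceed as in standard SD and is not shown.

```python
import random
from typing import Callable, Dict, Optional, Sequence

Dist = Dict[str, float]  # token -> probability at one position


def sd_accepts(token: str, p_target: Dist, q_draft: Dist, rng: random.Random) -> bool:
    """Standard SD test: accept the drafted token with probability min(1, p(x)/q(x))."""
    ratio = p_target.get(token, 0.0) / q_draft[token]  # q(x) > 0 since x was drafted
    return rng.random() < min(1.0, ratio)


def pad_accepts(token: str, p_target: Dist, q_draft: Dist, features: Sequence[float],
                pivot_prob: Callable[[Sequence[float]], float], threshold: float,
                rng: random.Random) -> bool:
    """PAD rule: keep the token if SD accepts it, or if it is predicted non-pivot."""
    if sd_accepts(token, p_target, q_draft, rng):
        return True
    # SD would reject here; override only when the classifier deems the token non-pivotal.
    return pivot_prob(features) < threshold


def pad_verify(draft: Sequence[str], p_targets: Sequence[Dist], q_drafts: Sequence[Dist],
               features: Sequence[Sequence[float]],
               pivot_prob: Callable[[Sequence[float]], float],
               threshold: float = 0.5, rng: Optional[random.Random] = None) -> int:
    """Return how many leading draft tokens are accepted in one verification step."""
    if rng is None:
        rng = random.Random(0)
    for i, token in enumerate(draft):
        if not pad_accepts(token, p_targets[i], q_drafts[i], features[i],
                           pivot_prob, threshold, rng):
            return i  # the target would resample position i, as in standard SD
    return len(draft)
```

Setting `threshold=0.0` disables the fallback (a probability is never below zero), so `pad_verify` then reduces to the standard SD acceptance test.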

Paper Structure

This paper contains 20 sections, 1 theorem, 19 equations, 3 figures, 1 table.

Key Result

Lemma 1

If $f_{\text{pivot}}$ has $100\%$ recall on pivot tokens (i.e., it never labels a pivot as non-pivot), then $p_{\text{PAD}}$ satisfies Definition 2 ($\epsilon$-utility-preserving decoding) with $\epsilon=0$.
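Lemma 1 ties utility preservation to the classifier's recall on pivots, which suggests choosing the decision threshold for recall first. The sketch below is an assumption rather than the paper's procedure, and its function names are hypothetical. It picks the largest threshold $t$ (predict pivot when score $\ge t$) that reaches a target pivot recall on a labeled validation set, and reports the fraction of non-pivot tokens that would then be accepted. Recall measured on held-out data only estimates the recall the lemma assumes; it does not guarantee it.

```python
import math
from typing import Sequence, Tuple


def threshold_for_recall(scores: Sequence[float], labels: Sequence[int],
                         target_recall: float = 1.0) -> float:
    """Largest t such that predicting 'pivot' when score >= t reaches target_recall.

    labels: 1 for pivot, 0 for non-pivot.
    """
    pivot_scores = sorted((s for s, y in zip(scores, labels) if y == 1), reverse=True)
    if not pivot_scores:
        raise ValueError("need at least one pivot example")
    # Number of pivots that must score at or above t (small epsilon guards float error).
    k = max(1, math.ceil(target_recall * len(pivot_scores) - 1e-9))
    return pivot_scores[k - 1]


def rates_at(threshold: float, scores: Sequence[float],
             labels: Sequence[int]) -> Tuple[float, float]:
    """Return (pivot recall, non-pivot acceptance rate) at a given threshold."""
    pivots = [s for s, y in zip(scores, labels) if y == 1]
    non_pivots = [s for s, y in zip(scores, labels) if y == 0]
    recall = sum(s >= threshold for s in pivots) / len(pivots)
    accepted = sum(s < threshold for s in non_pivots)
    return recall, (accepted / len(non_pivots) if non_pivots else 0.0)
```

Lowering `target_recall` raises the threshold, so more SD-rejected tokens are accepted (more speedup) at the cost of occasionally accepting a true pivot; this is the trade-off a ROC curve such as Figure 3 summarizes.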

Figures (3)

  • Figure 1: To match the target model's distribution, SD rejects many tokens that a draft model proposes (shown in blue). Most of these rejections are unnecessary, and fixing a single token (2$\rightarrow$1) is enough to recover the correct answer.
  • Figure 2: Pivot-Aware Speculative Decoding (PAD). (a) Dataset generation: label SD-rejected draft tokens via target-model rollouts with an LLM-as-judge sanity check; (b) Training: fit a pivot classifier on target-side features (layer-$\ell$ hidden states, logits, entropy); (c) Inference: accept tokens if standard SD accepts them, or if the classifier predicts non-pivot.
  • Figure 3: ROC curve for the pivot classifier on the held-out test set.
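Figure 2(b) lists the classifier inputs as target-side layer-$\ell$ hidden states, logits, and entropy. A minimal sketch of assembling such a feature vector and scoring it with a logistic model follows. The specific logit-derived features (draft-token logit and probability, top-1 probability, entropy), the function names, and the logistic form are assumptions, and the paper's classifier may differ. Plain Python lists stand in for tensors.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over the target model's vocabulary logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (in nats) of the target's next-token distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def pivot_features(hidden_state: Sequence[float],
                   target_logits: Sequence[float],
                   draft_token_id: int) -> List[float]:
    """Concatenate a layer-l hidden state with logit-derived scalars (assumed feature set)."""
    probs = softmax(target_logits)
    return list(hidden_state) + [
        target_logits[draft_token_id],  # target logit of the drafted token
        probs[draft_token_id],          # target probability of the drafted token
        max(probs),                     # target's top-1 probability
        entropy(probs),                 # uncertainty of the target distribution
    ]


def pivot_prob(features: Sequence[float], weights: Sequence[float], bias: float) -> float:
    """Logistic score: estimated probability that the drafted token is a pivot."""
    z = bias + sum(w * x for w, x in zip(weights, features))
    # Branch on the sign of z so math.exp never overflows.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)
```

All features come from the target model's forward pass, which SD already runs to verify the draft, so in this sketch the classifier adds only a dot product per rejected token.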

Theorems & Definitions (5)

  • Definition 1: Utility
  • Definition 2: $\epsilon$-Utility preserving decoding
  • Definition 3: Pivot Token
  • Lemma 1: Rejecting only pivot tokens preserves utility
  • Proof: Proof of Lemma 1