Judge Decoding: Faster Speculative Sampling Requires Going Beyond Model Alignment

Gregor Bachmann, Sotiris Anagnostidis, Albert Pumarola, Markos Georgopoulos, Artsiom Sanakoyeu, Yuming Du, Edgar Schönfeld, Ali Thabet, Jonas Kohler

TL;DR

This paper identifies a fundamental bottleneck in speculative decoding: alignment-based verification often rejects correct but non-aligned tokens, which limits the achievable speedup for large language models. It proposes Judge Decoding, which adds a lightweight judge head trained on the target model's embeddings to assess token correctness, so that more draft tokens can be accepted without sacrificing target quality. The judge is a 16k/8k-parameter linear head trained on about 30k tokens. A draft token is accepted if the judge output exceeds a simple threshold or if standard verification accepts it. Empirically, Judge Decoding achieves a speedup of up to 9× over Llama-3.1-405B with an 8B draft and up to 141 tokens/s for 8B/70B in optimized inference, while preserving performance on GSM8K, MT-Bench, and HumanEval; experiments on out-of-distribution tasks show partial generalization. The approach offers a practical, data-efficient path to faster inference by shifting verification from strict alignment to correctness judgments, although it gives up the formal guarantees of traditional speculative decoding.
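The acceptance rule summarized above can be sketched as follows. This is a minimal illustration, not the authors' implementation: the function names, the sigmoid on the head output, the 0.5 threshold, and the toy two-dimensional embeddings are all assumptions made for the example. It also assumes, as in standard speculative decoding, that every draft token after the first rejected one is discarded.

```python
import numpy as np


def sigmoid(x):
    """Logistic function mapping head outputs to (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def judge_accept_mask(embeddings, weight, bias, standard_mask, threshold=0.5):
    """Per-token acceptance: judge score above threshold OR standard acceptance.

    embeddings:    (M, d) target-model embeddings of the M candidate tokens
    weight, bias:  parameters of a hypothetical linear judge head (d weights)
    standard_mask: (M,) booleans from standard speculative verification
    """
    scores = sigmoid(embeddings @ weight + bias)
    judge_mask = scores > threshold
    return judge_mask | np.asarray(standard_mask, dtype=bool)


def accepted_prefix_length(mask):
    """Number of leading accepted tokens; the first rejection ends the run."""
    mask = np.asarray(mask, dtype=bool)
    rejected = np.flatnonzero(~mask)
    return int(rejected[0]) if rejected.size else int(mask.size)


# Toy example with d = 2 and M = 4 candidate tokens (values are made up).
toy_embeddings = np.array([[2.0, 0.0], [-2.0, 0.0], [2.0, 0.0], [-2.0, 0.0]])
toy_weight = np.array([1.0, 0.0])
toy_standard = np.array([False, True, False, False])

combined = judge_accept_mask(toy_embeddings, toy_weight, 0.0, toy_standard)
print(accepted_prefix_length(toy_standard))  # standard verification alone: 0
print(accepted_prefix_length(combined))      # judge OR standard: 3
```

In this toy case standard verification rejects the very first token, so nothing is accepted, while the combined rule accepts three tokens before the first rejection.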

Abstract

The performance of large language models (LLMs) is closely linked to their underlying size, leading to ever-growing networks and hence slower inference. Speculative decoding has been proposed as a technique to accelerate autoregressive generation, leveraging a fast draft model to propose candidate tokens, which are then verified in parallel based on their likelihood under the target model. While this approach is guaranteed to reproduce the target output, it incurs a substantial penalty: many high-quality draft tokens are rejected, even when they represent objectively valid continuations. Indeed, we show that even powerful draft models such as GPT-4o, as well as human text, cannot achieve high acceptance rates under the standard verification scheme. This severely limits the speedup potential of current speculative decoding methods, as an early rejection becomes overwhelmingly likely when relying solely on alignment of draft and target. We thus ask the following question: Can we adapt verification to recognize correct, but non-aligned replies? To this end, we draw inspiration from the LLM-as-a-judge framework, which demonstrated that LLMs are able to rate answers in a versatile way. We carefully design a dataset to elicit the same capability in the target model by training a compact module on top of the embeddings to produce "judgements" of the current continuation. We showcase our strategy on the Llama-3.1 family, where our 8B/405B-Judge achieves a speedup of 9x over Llama-405B, while maintaining its quality on a wide range of benchmarks. These benefits remain present even in optimized inference frameworks, where our method reaches up to 141 tokens/s for 8B/70B-Judge and 129 tokens/s for 8B/405B on 2 and 8 H100s, respectively.
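To see why an early rejection caps the speedup, the following sketch computes the expected number of tokens produced per target forward pass under a simplified model that is not taken from this paper: each draft token is assumed to be accepted independently with the same probability `alpha`, and the target contributes one extra token per pass. The function name and both parameters are hypothetical.

```python
def expected_tokens_per_step(alpha, num_draft):
    """Expected tokens per target pass under an i.i.d. acceptance assumption.

    alpha:     assumed per-token acceptance probability, in [0, 1]
    num_draft: number of draft tokens M proposed per pass
    Sums alpha**k for k = 0..M, i.e. the geometric series
    (1 - alpha**(M + 1)) / (1 - alpha).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    if num_draft < 0:
        raise ValueError("num_draft must be non-negative")
    if alpha == 1.0:
        # Every draft token is accepted, plus the extra target token.
        return float(num_draft + 1)
    return (1.0 - alpha ** (num_draft + 1)) / (1.0 - alpha)


# Compare a few assumed acceptance rates for M = 20 draft tokens.
for alpha in (0.6, 0.8, 0.95):
    print(alpha, round(expected_tokens_per_step(alpha, 20), 2))
```

Under this assumption the expectation can never exceed `1 / (1 - alpha)` regardless of how many tokens are drafted, so raising the per-token acceptance rate matters far more than drafting longer sequences.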

Paper Structure

This paper contains 34 sections, 5 equations, 13 figures, 1 table.

Figures (13)

  • Figure 1: Standard speculative decoding versus our judge decoding strategy for Llama-3.1-8B as draft and Llama-3.1-405B as target. Accepted (rejected) tokens are highlighted in green (red).
  • Figure 2: Left: Average number of generated tokens as a function of the number of draft tokens $M$ for Llama-8B/405B with standard and judge verification. Right: Number of accepted tokens on high-quality human text (top) and for both 8B/405B and 405B/8B (bottom), both standard SD.
  • Figure 3: Left: Standard SD and our judge decoding when GPT-4o is drafting and Llama-405B is verifying. Green denotes accepted and red rejected tokens. Right: Number of accepted tokens for GPT-4o as draft and Llama-405B as target for standard speculative and our judge verification.
  • Figure 4: Two examples from our dataset. We highlight the incorrect tokens in the wrong answer in red.
  • Figure 5: Left: Conditioning Llama-405B on wrong outputs. The part of the assistant response in red was forced, while parts in green were generated freely. Right: Judge illustration where $s_L$ is the last token from the context $\bm{s}$ and $c_1, \dots, c_M$ are candidate tokens. Orange denotes embeddings, green denotes the LM-head output and red denotes the produced judgements.
  • ...and 8 more figures