Direct Alignment of Draft Model for Speculative Decoding with Chat-Fine-Tuned LLMs

Raghavv Goel, Mukul Gagrani, Wonseok Jeon, Junyoung Park, Mingu Lee, Christopher Lott

TL;DR

This work addresses the memory-bandwidth bottleneck that slows auto-regressive decoding in large language models, using speculative decoding with a compact draft model aligned to a target chat model. It introduces a three-phase pipeline (draft pretraining, distillation-data generation from the target, and finetuning with a white-box distillation objective) together with a TVD++ loss that applies policy-gradient variance reduction to improve alignment. The authors train Llama 2 Chat Drafter 115M (1.64% of the size of the 7B target) and demonstrate a block efficiency of up to 2.3 and a speed-up of up to $2.4\times$ on tasks such as open-ended generation and summarization, without extra task-specific fine-tuning. Because the draft is aligned directly to the target, the approach removes the dependence on an off-the-shelf draft model and enables more efficient edge inference for chat-capable models.

Abstract

Text generation with Large Language Models (LLMs) is known to be memory bound due to the combination of their auto-regressive nature, huge parameter counts, and limited memory bandwidths, often resulting in low token rates. Speculative decoding has been proposed as a solution for LLM inference acceleration. However, since draft models are often unavailable in modern open-source LLM families, e.g., for Llama 2 7B, training a high-quality draft model is required to enable inference acceleration via speculative decoding. In this paper, we propose a simple draft model training framework for direct alignment to chat-capable target models. With the proposed framework, we train Llama 2 Chat Drafter 115M, a draft model for Llama 2 Chat 7B or larger, with only 1.64% of the original size. Our training framework consists only of pretraining, distillation dataset generation, and finetuning with knowledge distillation, with no additional alignment procedure. For the finetuning step, we use instruction-response pairs generated by the target model so that distillation is performed on a plausible data distribution, and we propose a new Total Variation Distance++ (TVD++) loss that incorporates variance reduction techniques inspired by the policy gradient method in reinforcement learning. Our empirical results show that Llama 2 Chat Drafter 115M with speculative decoding achieves a block efficiency of up to 2.3 and a speed-up of up to 2.4$\times$ relative to autoregressive decoding on various tasks with no further task-specific fine-tuning.
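The abstract relies on speculative decoding without restating its verification rule. For reference, the standard speculative-sampling rule accepts a token $x$ drafted from $p$ with probability $\min(1, q(x)/p(x))$ and otherwise resamples from the normalized residual $\max(0, q - p)$; the output then follows the target distribution $q$ exactly, and the per-token acceptance probability is $\sum_x \min(p(x), q(x)) = 1 - \mathrm{TVD}(p, q)$. This identity is one natural motivation for a TVD-based distillation loss. The NumPy sketch below illustrates the rule for a single position; the function names are hypothetical and this is not the authors' implementation.

```python
import numpy as np

def total_variation(p: np.ndarray, q: np.ndarray) -> float:
    """TVD(p, q) = 0.5 * sum_x |p(x) - q(x)| over a shared vocabulary."""
    return 0.5 * float(np.abs(p - q).sum())

def acceptance_rate(p: np.ndarray, q: np.ndarray) -> float:
    """Probability that a token drafted from p is accepted: sum_x min(p(x), q(x))."""
    return float(np.minimum(p, q).sum())

def verify_token(x: int, p: np.ndarray, q: np.ndarray, rng: np.random.Generator):
    """One speculative-sampling verification step (hypothetical helper).

    x: token sampled from the draft distribution p (so p[x] > 0).
    q: target distribution at the same position.
    Returns (accepted, emitted_token).
    """
    # Accept the drafted token with probability min(1, q(x) / p(x)).
    if rng.random() < min(1.0, q[x] / p[x]):
        return True, x
    # On rejection, resample from the normalized residual max(0, q - p).
    residual = np.maximum(q - p, 0.0)
    residual /= residual.sum()
    return False, int(rng.choice(len(q), p=residual))
```

For example, with $p = (0.5, 0.3, 0.2)$ and $q = (0.4, 0.4, 0.2)$ the TVD is 0.1 and the acceptance rate is 0.9, so a draft that lowers the TVD to its target directly raises the expected number of accepted tokens.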

Paper Structure

This paper contains 13 sections, 2 theorems, 4 equations, 3 figures, and 1 table.

Key Result

Lemma 1

The gradient of the total variation distance $\mathrm{TVD}(p_\theta, q)$ between the draft model $p_\theta$ and the target model $q$ with respect to the draft model parameter $\theta$ is $\nabla_\theta \mathrm{TVD}(p_\theta, q) = \mathbb{E}_{X\sim p_\theta} \left[ \nabla_\theta \log p_\theta(X)\,(-r(X)) \right]$, where $r(x) = \mathbb{1}\{q(x) > p_\theta(x)\}$.
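Lemma 1 rewrites the TVD gradient in the score-function (policy-gradient) form, with $r(X)$ playing the role of a reward. The sketch below checks this numerically for a single categorical distribution $p_\theta = \mathrm{softmax}(\theta)$, assuming $\mathrm{TVD}(p, q) = \tfrac{1}{2}\sum_x |p(x) - q(x)|$ and $r(x) = \mathbb{1}\{q(x) > p_\theta(x)\}$. It compares the score-function form against a direct chain-rule gradient and a finite-difference check; the function names are hypothetical, and a real draft model would replace the single logit vector with per-position LLM outputs.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax of a 1-D logit vector."""
    e = np.exp(z - z.max())
    return e / e.sum()

def tvd(theta: np.ndarray, q: np.ndarray) -> float:
    """TVD(p_theta, q) = 0.5 * sum_x |p_theta(x) - q(x)| with p_theta = softmax(theta)."""
    return 0.5 * float(np.abs(softmax(theta) - q).sum())

def tvd_grad_direct(theta: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Chain-rule gradient of tvd() w.r.t. the logits (valid away from ties p(x) == q(x))."""
    p = softmax(theta)
    dtvd_dp = 0.5 * np.sign(p - q)
    jac = np.diag(p) - np.outer(p, p)  # jac[x, j] = d p(x) / d theta_j for a softmax
    return jac.T @ dtvd_dp

def tvd_grad_score_function(theta: np.ndarray, q: np.ndarray, baseline: float = 0.0) -> np.ndarray:
    """Lemma 1 form: E_{X~p}[grad log p(X) * (-(r(X) - baseline))], r(x) = 1{q(x) > p(x)}.

    The expectation is summed exactly over the small vocabulary instead of being sampled.
    """
    p = softmax(theta)
    r = (q > p).astype(float)
    grad_log_p = np.eye(len(p)) - p  # row x holds grad_theta log p(x) = e_x - p
    return (p * -(r - baseline)) @ grad_log_p
```

In training, the expectation would instead be estimated from samples $X \sim p_\theta$. Because $\mathbb{E}_{p_\theta}[\nabla_\theta \log p_\theta(X)] = 0$, subtracting any constant baseline from $r$ leaves the exact gradient unchanged, while a well-chosen baseline can lower the variance of the sampled estimate. This is the general policy-gradient idea the abstract says TVD++ draws on; the exact TVD++ normalization is the one defined in the paper, not this constant-baseline toy.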

Figures (3)

  • Figure 1: Draft models evaluated on Memory-Bound Speed-Up (MBSU) and token rate (relative to auto-regressive generation) for multiple tasks (Dolly, CNN-DailyMail, XSum), draft lengths (3, 5), and training losses (KLD, TVD, TVD++).
  • Figure 2: Draft models evaluated on block efficiency (draft length $\gamma=3$) over different checkpoints (ckpt) across the fine-tuning stage, showing improvement over the base draft model, for multiple tasks (Dolly, CNN-DM, XSum) and training losses (KLD, TVD, TVD++).
  • Figure 3: Block efficiency on WMT18-DeEn for multiple draft models.
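Figures 1 and 2 report memory-bound speed-up and block efficiency (the average number of tokens produced per target-model call). As a rough guide to how these quantities relate, the sketch below uses two textbook simplifications that are assumptions here, not the paper's exact definitions: each drafted token is accepted independently with probability $\alpha$, so with draft length $\gamma$ the expected block efficiency is $(1 - \alpha^{\gamma+1})/(1 - \alpha)$; and every forward pass costs time proportional to parameter count, so one block costs $\gamma$ draft passes plus one target pass. The function names are hypothetical.

```python
def expected_block_efficiency(alpha: float, gamma: int) -> float:
    """Expected tokens per target call if each of gamma drafted tokens is accepted
    i.i.d. with probability alpha (plus the one token the target always contributes)."""
    if alpha >= 1.0:
        return gamma + 1.0
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)

def memory_bound_speedup(block_eff: float, gamma: int, size_ratio: float) -> float:
    """Speed-up over autoregressive decoding, assuming each forward pass costs time
    proportional to parameter count. size_ratio = draft params / target params."""
    # One block: gamma draft passes (each size_ratio of a target pass) + 1 target pass,
    # versus block_eff target passes to produce the same tokens autoregressively.
    return block_eff / (1.0 + gamma * size_ratio)
```

For example, with the 1.64% size ratio quoted above, a draft length of 3, and a hypothetical block efficiency of 2.3, the second function gives $2.3 / (1 + 3 \times 0.0164) \approx 2.19$. A small draft keeps the denominator close to 1, so the speed-up tracks block efficiency; measured wall-clock speed-ups also depend on hardware and implementation overheads.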

Theorems & Definitions (4)

  • Lemma 1
  • Proof
  • Lemma 1 (restated)
  • Proof