CORAL: Learning Consistent Representations across Multi-step Training with Lighter Speculative Drafter

Yepeng Weng, Dianwen Mei, Huishi Qiu, Xujie Chen, Li Liu, Jiang Tian, Zhongchao Shi

TL;DR

CORAL tackles training-inference misalignment in speculative decoding by introducing Cross-Step Representation Alignment (CSRA), a contrastive objective that enforces consistency of draft representations across multiple training steps, and a lightweight LM Head Router that selectively activates subsets of the LM head to reduce drafting latency for large-vocabulary models. The approach also provides practical speedup estimation and a two-stage training protocol for the router, enabling a plug-and-play improvement to existing speculative decoders. Empirical results across Llama3, Llama2, and Qwen models on MT-bench, HumanEval, and GSM8K show 2.50x–4.07x speedups over vanilla decoding, outperforming state-of-the-art methods like EAGLE-2 and HASS. The combination of CSRA and routing yields higher stability and faster drafting with manageable training and deployment costs, offering a practical path to faster inference for modern LLMs with large vocabularies.
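The TL;DR mentions practical speedup estimation. The paper's exact formula is not reproduced here. The sketch below is a generic back-of-envelope cost model built on stated assumptions: one draft-then-verify cycle runs the draft model `draft_steps` times and the target model once, and it commits `avg_accept_len` tokens on average. The function name, its parameters, and the timings in the example are hypothetical.

```python
from typing import Optional


def estimate_speedup(avg_accept_len: float, t_target: float, t_draft: float,
                     draft_steps: int, t_verify: Optional[float] = None) -> float:
    """Rough speedup of speculative decoding over vanilla autoregressive decoding.

    Assumed cost model (illustrative, not taken from the paper):
      * vanilla decoding emits 1 token per target forward pass (t_target);
      * one speculative cycle runs the draft model `draft_steps` times and
        verifies once with the target model (t_verify, default t_target),
        committing `avg_accept_len` tokens on average.
    """
    if t_verify is None:
        t_verify = t_target  # assume verification costs about one target forward pass
    cycle_time = t_verify + draft_steps * t_draft
    # (tokens per unit time, speculative) / (tokens per unit time, vanilla)
    return avg_accept_len * t_target / cycle_time


# Hypothetical timings in milliseconds, chosen only for illustration.
baseline = estimate_speedup(avg_accept_len=5.0, t_target=25.0, t_draft=2.5, draft_steps=6)
faster_draft = estimate_speedup(avg_accept_len=5.0, t_target=25.0, t_draft=1.25, draft_steps=6)
print(f"{baseline:.3f}x -> {faster_draft:.3f}x")  # prints 3.125x -> 3.846x
```

In this example the acceptance length stays fixed. Halving the per-step draft latency, which is the effect a lighter LM head aims for, still raises the estimate. Once acceptance is high, drafting cost therefore becomes a significant share of each cycle.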

Abstract

Speculative decoding is a powerful technique that accelerates Large Language Model (LLM) inference by leveraging a lightweight speculative draft model. However, existing designs suffer in performance due to misalignment between training and inference. Recent methods have tried to solve this issue by adopting a multi-step training strategy, but the complex inputs of different training steps make it harder for the draft model to converge. To address this, we propose CORAL, a novel framework that improves both accuracy and efficiency in speculative drafting. CORAL introduces Cross-Step Representation Alignment, a method that enhances consistency across multiple training steps, significantly improving speculative drafting performance. Additionally, we identify the LM head as a major bottleneck in the inference speed of the draft model. We introduce a weight-grouping mechanism that selectively activates a subset of LM head parameters during inference, substantially reducing the latency of the draft model. We evaluate CORAL on three LLM families and three benchmark datasets, achieving speedup ratios of 2.50x–4.07x and outperforming state-of-the-art methods such as EAGLE-2 and HASS. Our results demonstrate that CORAL effectively mitigates training-inference misalignment and delivers significant speedup for modern LLMs with large vocabularies.
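CSRA is described as a contrastive objective that keeps draft representations consistent across training steps. One plausible reading is an InfoNCE loss. In it, the features the draft model produces for the same position in different unrolled training steps are positive pairs, and features at other positions serve as negatives. This reading is an assumption, not the authors' implementation. The function name `cross_step_infonce`, the use of cosine similarity, and the temperature of 0.1 are illustrative choices.

```python
import numpy as np


def cross_step_infonce(step_feats: np.ndarray, temperature: float = 0.1) -> float:
    """InfoNCE-style consistency loss across unrolled training steps (sketch).

    step_feats: shape (num_steps, seq_len, hidden), the draft model's output
    feature at each position for each training step. For every ordered pair of
    distinct steps (a, b), position i in step a treats position i in step b as
    its positive and every other position j in step b as a negative.
    """
    num_steps, seq_len, _ = step_feats.shape
    # L2-normalise so dot products are cosine similarities.
    z = step_feats / np.linalg.norm(step_feats, axis=-1, keepdims=True)
    total, count = 0.0, 0
    for a in range(num_steps):
        for b in range(num_steps):
            if a == b:
                continue
            # logits[i, j] = cos(z[a, i], z[b, j]) / temperature
            logits = z[a] @ z[b].T / temperature
            # Numerically stable log-softmax over candidate positions j.
            logits = logits - logits.max(axis=1, keepdims=True)
            log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
            # The positive for position i sits on the diagonal.
            total += -np.mean(np.diag(log_prob))
            count += 1
    return total / count


# Toy check: 2 steps, 4 positions with orthogonal 8-dimensional features.
base = np.eye(4, 8)
consistent = np.stack([base, base])                   # same features in both steps
shuffled = np.stack([base, np.roll(base, 1, axis=0)])  # positions scrambled in step 2
print(cross_step_infonce(consistent), cross_step_infonce(shuffled))
```

Consistent step outputs give a loss near zero. Step outputs with scrambled positions give a loss of about 10. In training, a term like this would presumably be weighted and added to the usual draft-model losses. The weighting CORAL actually uses is not given in this section.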

Paper Structure

This paper contains 25 sections, 9 equations, 6 figures, 9 tables.

Figures (6)

  • Figure 1: Speedup ratios of different methods on Llama3-8B and Qwen2.5-7B at temperature 0, averaged over the MT-bench, HumanEval, and GSM8K datasets. Full results are presented in the full-results table; this chart shows only a subset of all comparisons.
  • Figure 2: Parameters and latencies of the Llama3-8B, Llama2-7B, and Qwen2.5-7B draft models. For a model with a large vocabulary, the LM head accounts for the majority of the drafting latency.
  • Figure 3: Demonstration of EAGLE training and inference, and of multi-step training with CSRA. $f$ denotes a feature and $e$ an embedding. Superscripts indicate the source of the variable, with $t$ and $d$ denoting the target model and the draft model. Subscripts index the position of a feature or embedding; for example, $f_{3}^{t}$ denotes the feature at position 3 produced by the target model. For multi-step training, apostrophes distinguish the outputs of different training steps: the output feature of step 1 is $f^{d}$, and those of steps 2 and 3 are $f^{d'}$ and $f^{d''}$, respectively. Compared to HASS, CSRA introduces additional constraints on feature consistency. The training target is applied at each step but is illustrated only once for clarity.
  • Figure 4: Comparison of EAGLE training, HASS training, and CSRA. Here $\bigcirc$ denotes the training target and $\bigtriangleup$ denotes output features from different steps. Triangles filled with darker colors represent the first step's output. Different colors represent outputs or targets at different positions. The optimization direction is marked as $\to$, and the dashed $\leftrightarrow$ denotes repulsion.
  • Figure 5: Demonstration of the LM head router in the draft model. With the router, the draft model outputs probabilities only for one or more subsets of the vocabulary.
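The Figure 2 claim that the LM head dominates drafting cost for large vocabularies can be sanity-checked by counting parameters. This sketch assumes an EAGLE-style draft model with one decoder layer plus a fusion layer that maps the concatenated feature and embedding (2h) to h. It ignores norms and biases. The configuration values come from the public Hugging Face configs of the base models. Parameter count is only a proxy for latency, though it is a reasonable one for memory-bound single-token drafting. `Config` and `draft_param_breakdown` are hypothetical helpers.

```python
from dataclasses import dataclass


@dataclass
class Config:
    name: str
    hidden: int
    intermediate: int
    n_heads: int
    n_kv_heads: int
    vocab: int


def draft_param_breakdown(cfg: Config) -> dict:
    """Approximate parameter counts of an assumed EAGLE-style draft model."""
    head_dim = cfg.hidden // cfg.n_heads
    kv_dim = cfg.n_kv_heads * head_dim
    attn = 2 * cfg.hidden * cfg.hidden + 2 * cfg.hidden * kv_dim  # q, o + k, v (GQA-aware)
    mlp = 3 * cfg.hidden * cfg.intermediate                       # gate, up, down projections
    fusion = 2 * cfg.hidden * cfg.hidden                          # concat(feature, embedding) -> hidden
    lm_head = cfg.vocab * cfg.hidden
    total = attn + mlp + fusion + lm_head
    return {"decoder_layer": attn + mlp, "fusion": fusion,
            "lm_head": lm_head, "lm_head_share": lm_head / total}


CONFIGS = [
    Config("Llama3-8B", 4096, 14336, 32, 8, 128256),
    Config("Llama2-7B", 4096, 11008, 32, 32, 32000),
    Config("Qwen2.5-7B", 3584, 18944, 28, 4, 152064),
]

for cfg in CONFIGS:
    b = draft_param_breakdown(cfg)
    print(f"{cfg.name}: LM head {b['lm_head'] / 1e6:.0f}M params, "
          f"{b['lm_head_share']:.0%} of the draft model")
```

Under these assumptions, the LM head is about two-thirds of the draft model's parameters for Llama3-8B and Qwen2.5-7B. For Llama2-7B it is only about one-third, because that model's 32K vocabulary is roughly a quarter the size of Llama3's.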
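Figure 5 and the abstract describe grouping the LM head weights and activating only some groups at inference. The sketch below is a minimal version built on assumptions the source does not state. The vocabulary is split into contiguous, equal-size groups. The router is a single linear layer that scores each group, and the top-k groups are kept. The name `routed_lm_head` and all sizes are hypothetical. Router training, which the paper handles with a two-stage protocol, is omitted.

```python
import numpy as np


def routed_lm_head(h, lm_head_w, router_w, groups, top_k=1):
    """Draft-token probabilities over only the top-k vocabulary groups (sketch).

    h:         (hidden,) draft-model hidden state for one position
    lm_head_w: (vocab, hidden) full LM head weight matrix
    router_w:  (num_groups, hidden) hypothetical linear router
    groups:    list of int arrays; groups[g] holds the token ids of group g
    Returns (token_ids, probs) restricted to the selected groups.
    """
    group_scores = router_w @ h                        # one score per vocabulary group
    selected = np.argsort(group_scores)[::-1][:top_k]  # highest-scoring groups first
    token_ids = np.concatenate([groups[g] for g in selected])
    # Only the LM head rows of the selected groups are read and multiplied.
    logits = lm_head_w[token_ids] @ h
    logits -= logits.max()                             # stabilise the softmax
    probs = np.exp(logits) / np.exp(logits).sum()
    return token_ids, probs


# Toy usage with made-up sizes: a 1,000-token vocabulary in 8 contiguous groups.
rng = np.random.default_rng(0)
vocab, hidden, num_groups = 1000, 16, 8
lm_head_w = rng.standard_normal((vocab, hidden))
router_w = rng.standard_normal((num_groups, hidden))
groups = np.array_split(np.arange(vocab), num_groups)
h = rng.standard_normal(hidden)
token_ids, probs = routed_lm_head(h, lm_head_w, router_w, groups, top_k=2)
print(f"computed {len(token_ids)} of {vocab} logits")  # prints: computed 250 of 1000 logits
```

With 2 of 8 groups selected, only a quarter of the LM head rows are read. The trade-off is that a token outside the selected groups cannot be drafted, so router accuracy directly affects acceptance length.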