
Jakiro: Boosting Speculative Decoding with Decoupled Multi-Head via MoE

Haiduo Huang, Fuwei Yang, Zhenhua Liu, Yixing Xu, Jinze Li, Yang Liu, Xuanwu Yin, Dong Li, Pengju Ren, Emad Barsoum

TL;DR

Jakiro targets a bottleneck in tree-based speculative decoding: draft candidates at the same step are derived from the same representation, which limits their diversity. It introduces dynamic decoupling through Mixture-of-Experts (MoE) heads to diversify draft-token predictions while preserving speed. It pairs a hybrid inference strategy (autoregressive decoding for the initial draft tokens, parallel decoding for the subsequent ones) with a feature-level contrastive mechanism that improves draft quality without adding latency. The approach achieves state-of-the-art speedups for Vicuna, LLaMA2-Chat, and LLaMA3-Instruct on multiple benchmarks, performs strongly in non-greedy settings, and is robust across hardware. The work suggests a practical route to more scalable LLM inference by combining MoE-based diversification, parallel decoding, and contrastive refinement.
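
To make the decoupling idea concrete, the toy sketch below contrasts two ways of producing k draft candidates at one tree step. A single head takes the top-k tokens of one distribution, so every candidate is read off the same representation; the MoE variant lets a router pick k experts, each proposing its own best token. The function names, the softmax router, and the rule of skipping tokens already proposed are illustrative assumptions, not Jakiro's actual implementation.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_indices(scores: List[float], k: int) -> List[int]:
    """Indices of the k largest scores (ties keep their original order)."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def single_head_candidates(logits: List[float], k: int) -> List[int]:
    """Baseline: all k candidates are the top-k tokens of ONE distribution,
    i.e. they all come from the same hidden representation."""
    return top_indices(logits, k)


def moe_decoupled_candidates(router_logits: List[float],
                             expert_logits: List[List[float]],
                             k: int) -> List[int]:
    """Hypothetical sketch: route to the top-k experts, then let each selected
    expert propose its own highest-scoring token not already proposed."""
    gate = softmax(router_logits)          # gating weights over experts
    chosen_experts = top_indices(gate, k)  # experts in descending gate order
    candidates: List[int] = []
    for e in chosen_experts:
        for tok in top_indices(expert_logits[e], len(expert_logits[e])):
            if tok not in candidates:      # keep candidates distinct
                candidates.append(tok)
                break
    return candidates


if __name__ == "__main__":
    print(single_head_candidates([0.1, 2.0, 1.5, 0.3], k=2))  # [1, 2]
    experts = [[3.0, 0.0, 0.0, 0.0],   # expert 0
               [0.0, 0.0, 4.0, 1.0],   # expert 1
               [0.0, 0.0, 3.5, 2.0]]   # expert 2
    print(moe_decoupled_candidates([0.5, 2.0, 1.0], experts, k=2))  # [2, 3]
```

In the example, the router selects experts 1 and 2; expert 2's favorite token (2) duplicates expert 1's choice, so it contributes its runner-up (3) instead. In a real drafter the expert outputs would come from learned modules rather than fixed toy logits.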

Abstract

Speculative decoding (SD) accelerates large language model inference by using a smaller draft model to predict multiple tokens, which are then verified in parallel by the larger target model. However, the limited capacity of the draft model often necessitates tree-based sampling to improve prediction accuracy, where multiple candidates are generated at each step. We identify a key limitation in this approach: the candidates at the same step are derived from the same representation, which limits diversity and reduces overall effectiveness. To address this, we propose Jakiro, which leverages Mixture of Experts (MoE): independent experts generate diverse predictions, effectively decoupling the correlations among candidates. Furthermore, we introduce a hybrid inference strategy that combines autoregressive decoding for the initial tokens with parallel decoding for subsequent stages, and we enhance the latter with a feature-level contrastive mechanism to improve accuracy. Our method significantly boosts prediction accuracy and achieves higher inference speedups. Extensive experiments across diverse models validate the effectiveness and robustness of our approach, establishing a new state of the art (SOTA) in speculative decoding. Our code is available at https://github.com/haiduo/Jakiro.
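
The draft-then-verify step that the abstract builds on can be sketched for the greedy case. This is the standard speculative-decoding acceptance rule, shown for orientation rather than as code from Jakiro; `verify_greedy` and `target_argmax` are hypothetical names, and the target model is replaced by a toy function. Non-greedy settings, such as the temperature-1 results in Figure 2, use a probabilistic accept/reject rule instead.

```python
from typing import Callable, List, Sequence

Token = int


def verify_greedy(prefix: Sequence[Token],
                  draft: Sequence[Token],
                  target_argmax: Callable[[List[Token]], Token]) -> List[Token]:
    """Generic greedy speculative-decoding verification.

    Accepts the longest prefix of `draft` that agrees with the target model's
    greedy choice at each position, then appends one token produced by the
    target: the correction at the first mismatch, or a bonus token if every
    draft token was accepted. The returned tokens therefore match what greedy
    decoding with the target model alone would produce.
    """
    context = list(prefix)
    accepted: List[Token] = []
    for tok in draft:
        # A real system obtains all of these predictions from ONE parallel
        # forward pass of the target model; we query sequentially for clarity.
        expected = target_argmax(context)
        if tok != expected:
            accepted.append(expected)  # target's correction; stop here
            return accepted
        accepted.append(tok)
        context.append(tok)
    accepted.append(target_argmax(context))  # bonus token after full acceptance
    return accepted


def toy_target(ctx: List[Token]) -> Token:
    """Toy stand-in for the target model: previous token plus one (mod 10)."""
    return (ctx[-1] + 1) % 10


if __name__ == "__main__":
    print(verify_greedy([1], [2, 3, 5, 6], toy_target))  # [2, 3, 4]
    print(verify_greedy([1], [2, 3], toy_target))        # [2, 3, 4]
```

With a draft tree rather than a single chain, the same check is applied along each root-to-leaf path and the longest accepted path is kept, which is why more diverse candidates per step can raise the number of tokens accepted per target forward pass.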

Paper Structure

This paper contains 19 sections, 3 equations, 7 figures, 6 tables.

Figures (7)

  • Figure 1: Comparison of different speculative decoding methods.
  • Figure 2: Inference speedup ratios of Vicuna, LLaMA2-Chat, and LLaMA3-Instruct models on MT-bench in the non-greedy setting (temperature = 1). The reproduced results are based on the open-source code from the original papers and are averaged over four inference runs on an A100-40G GPU. In this paper, we compare only with speculative sampling methods that do not require fine-tuning the backbone models, which ensures that the output text distribution remains unchanged.
  • Figure 3: Comparison of different methods for building the draft tree.
  • Figure 4: Construction process of the tree attention mask.
  • Figure 5: Efficient integration of the contrastive mechanism.
  • ...and 2 more figures
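
Figure 4 concerns the tree attention mask, which lets the target model verify all branches of a draft tree in a single forward pass. A common construction in tree-based speculative decoding, sketched here under the assumption that the tree is stored as parent pointers, lets each node attend only to itself and its ancestors. The `parents` encoding and the function name are illustrative assumptions rather than Jakiro's data layout; in practice this block is also combined with full attention to the already-accepted prefix.

```python
from typing import List


def tree_attention_mask(parents: List[int]) -> List[List[int]]:
    """Build a tree attention mask from parent pointers.

    parents[i] is the index of draft node i's parent, or -1 for the root.
    mask[i][j] == 1 exactly when node j is node i itself or one of its
    ancestors, so each draft token attends only to its own branch and never
    to sibling branches. Assumes `parents` describes a valid tree (no cycles).
    """
    n = len(parents)
    mask = [[0] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:          # walk from node i up to the root
            mask[i][j] = 1
            j = parents[j]
    return mask


if __name__ == "__main__":
    # Root 0 has children 1 and 2; node 3 is a child of node 1.
    for row in tree_attention_mask([-1, 0, 0, 1]):
        print(row)
    # [1, 0, 0, 0]
    # [1, 1, 0, 0]
    # [1, 0, 1, 0]
    # [1, 1, 0, 1]
```

For a tree that is a single chain, the function reduces to the ordinary lower-triangular causal mask.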