
Optimized Multi-Token Joint Decoding with Auxiliary Model for LLM Inference

Zongyue Qin, Ziniu Hu, Zifan He, Neha Prakriya, Jason Cong, Yizhou Sun

TL;DR

This work tackles the costly autoregressive decoding of large language models by proposing MTJD, which generates multiple tokens at each step from their joint distribution, and MTAD, an efficient approximation that uses a smaller auxiliary model to draft tokens and the large model to verify them. The authors prove that MTAD approximates exact MTJD with bounded error and demonstrate practical gains across Llama-2 and OPT models (13B to 70B parameters) and diverse tasks, including substantial perplexity reductions and improved downstream performance. Empirically, MTAD achieves about a 1.26–1.42x speed-up and reduced energy use (up to ~23–25%) compared with competitive speculative approaches, while MMTAD, a variant that verifies all intermediate beams of beam sampling rather than only the output beam, further improves efficiency. Overall, the approach offers a principled path to faster, more accurate, and energy-efficient LLM inference through multi-token joint decoding and auxiliary-model-assisted verification.

Abstract

Large language models (LLMs) have achieved remarkable success across diverse tasks, yet their inference processes are hindered by substantial time and energy demands due to single-token generation at each decoding step. While previous methods such as speculative decoding mitigate these inefficiencies by producing multiple tokens per step, each token is still generated by its single-token distribution, thereby enhancing speed without improving effectiveness. In contrast, our work improves both inference speed and output effectiveness. We consider multi-token joint decoding (MTJD), which generates multiple tokens from their joint distribution at each iteration, theoretically reducing perplexity and enhancing task performance. However, MTJD suffers from the high cost of sampling from the joint distribution of multiple tokens. Inspired by speculative decoding, we introduce multi-token assisted decoding (MTAD), a novel framework designed to accelerate MTJD. MTAD leverages a smaller auxiliary model to approximate the joint distribution of a larger model, incorporating a verification mechanism that not only ensures the accuracy of this approximation but also improves decoding efficiency over conventional speculative decoding. Theoretically, we demonstrate that MTAD closely approximates exact MTJD with bounded error. Empirical evaluations using Llama-2 and OPT models ranging from 13B to 70B parameters across various tasks reveal that MTAD reduces perplexity by 21.2% and improves downstream performance compared to standard single-token sampling. Furthermore, MTAD achieves a 1.42x speed-up and consumes 1.54x less energy than conventional speculative decoding methods. These results highlight MTAD's ability to make multi-token joint decoding both effective and efficient, promoting more sustainable and high-performance deployment of LLMs.
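The toy sketch below illustrates why choosing a block of tokens by its joint probability can yield a lower-perplexity output than committing to one token at a time. Everything in it is hypothetical. The bigram "model" (`VOCAB`, `FIRST`, `NEXT`) is invented for illustration. Taking the mode of the joint distribution (`joint_mode`) is one simple deterministic decoding rule assumed here; it is not claimed to be the sampling procedure MTJD actually uses. The exhaustive search over |V|^k blocks also shows why exact joint decoding becomes expensive as k grows.

```python
import itertools
import math

# Hypothetical toy bigram "language model" (illustration only, not from the paper).
VOCAB = ("a", "b")
FIRST = {"a": 0.6, "b": 0.4}                  # P(x_1)
NEXT = {"a": {"a": 0.5, "b": 0.5},            # P(x_t | x_{t-1} = "a")
        "b": {"a": 0.95, "b": 0.05}}          # P(x_t | x_{t-1} = "b")


def token_prob(prev, tok):
    """Conditional probability of `tok` given the previous token (None at the start)."""
    return FIRST[tok] if prev is None else NEXT[prev][tok]


def joint_prob(block):
    """Joint probability of a token block via the chain rule."""
    p, prev = 1.0, None
    for tok in block:
        p *= token_prob(prev, tok)
        prev = tok
    return p


def perplexity(block):
    """exp of the average negative log-likelihood per token."""
    return math.exp(-math.log(joint_prob(block)) / len(block))


def greedy_block(k):
    """Single-token decoding: commit to the most likely next token at every step."""
    block, prev = [], None
    for _ in range(k):
        tok = max(VOCAB, key=lambda t: token_prob(prev, t))
        block.append(tok)
        prev = tok
    return tuple(block)


def joint_mode(k):
    """Choose all k tokens at once by exhaustive search over |VOCAB|**k blocks."""
    return max(itertools.product(VOCAB, repeat=k), key=joint_prob)


if __name__ == "__main__":
    for name, block in (("greedy", greedy_block(2)), ("joint", joint_mode(2))):
        print(name, block, round(joint_prob(block), 2), round(perplexity(block), 3))
```

Greedy decoding commits to "a" first (0.6 > 0.4) and ends with a joint probability of 0.3. The joint search finds ("b", "a") with probability 0.4 × 0.95 = 0.38, which gives the lower perplexity.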

Paper Structure

This paper contains 41 sections, 8 theorems, 36 equations, 7 figures, 11 tables, and 1 algorithm.

Key Result

Theorem 3.2

Assume that at the $i$-th iteration ($i=1,\ldots,N$), MTJD generates $\gamma_i$ tokens. Let $\Gamma_i$ denote the total number of tokens generated in the first $i$ iterations, and let $x_{1:\Gamma_N}$ denote the generated tokens. When $N \rightarrow \infty$, …, where $\bar{\gamma}$ is the expected value of $\gamma_i$ and $\tilde{p}=\mathcal{T}\circ p$ denotes how the next $\gamma_i$ tokens are sampled from $p$.
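For reference, the quantities in the theorem can be written out as follows. The cumulative count $\Gamma_i$ follows directly from its definition. Rewriting the per-token log-likelihood as a sum over iterations is the chain rule. Two points are assumptions rather than statements from the source: that output quality is measured by the standard perplexity under $p$, and that $\Gamma_N/N$ converges to $\bar{\gamma}$, which holds by the strong law of large numbers if the $\gamma_i$ are i.i.d. with finite mean.

```latex
\Gamma_0 = 0, \qquad \Gamma_i = \sum_{j=1}^{i} \gamma_j,
\qquad
\mathrm{PPL}(x_{1:\Gamma_N})
  = \exp\!\Big(-\frac{1}{\Gamma_N}\sum_{t=1}^{\Gamma_N}\log p(x_t \mid x_{<t})\Big)
  = \exp\!\Big(-\frac{1}{\Gamma_N}\sum_{i=1}^{N}\log p\big(x_{\Gamma_{i-1}+1:\Gamma_i} \mid x_{1:\Gamma_{i-1}}\big)\Big),
\qquad
\frac{\Gamma_N}{N} \xrightarrow[N\to\infty]{\text{a.s.}} \bar{\gamma}.
```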

Figures (7)

  • Figure 1: Perplexity and ROUGE-L score of the output of MTJD with $\gamma_i=K$, using OPT-125M and Llama-2-68M fine-tuned on the ChatGPT-Prompts dataset.
  • Figure 2: An example of MTAD's verification process. MTAD accepts the longest draft sub-sequence that passes verification based on joint likelihood.
  • Figure 3: Illustration of MMTAD: (a) All intermediate beams of beam sampling naturally form a tree. Vanilla MTAD verifies only the output beam (yellow blocks), whereas MMTAD verifies all the beams. (b) MMTAD uses tree attention to efficiently compute the target likelihood of each beam. (c) MMTAD returns the longest accepted sequence with the highest target likelihood.
  • Figure 4: Performance of MMTAD when draft length $\gamma\in\{3,4,5,6,7,8,9,10\}$.
  • Figure 5: Performance of MMTAD when beam width $b\in\{2,3,4,5,6\}$.
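The verification and selection steps described in Figures 2 and 3(c) can be sketched as follows. This is a minimal sketch under stated assumptions. Per-token log-probabilities of each draft token under the target model (p) and the auxiliary draft model (q) are taken as given. The pass/fail test `ratio_rule` and its `threshold` are hypothetical placeholders, because the exact acceptance criterion is not given here. `longest_accepted_prefix` checks every prefix by its joint log-likelihood and keeps the longest one that passes. `select_beam` mimics the MMTAD choice by treating each beam (a path in the beam tree) as a separate candidate.

```python
import math
from typing import Callable, List, Sequence, Tuple


# HYPOTHETICAL acceptance rule (not taken from the paper): a draft prefix passes if
# the target model's joint likelihood is at least `threshold` times the draft model's.
def ratio_rule(target_logp: float, draft_logp: float, threshold: float = 0.5) -> bool:
    return target_logp - draft_logp >= math.log(threshold)


def longest_accepted_prefix(
    target_tok_logp: Sequence[float],
    draft_tok_logp: Sequence[float],
    passes: Callable[[float, float], bool] = ratio_rule,
) -> Tuple[int, float]:
    """Test every prefix x_{1:k} by its joint log-likelihood and return
    (k, target joint log-likelihood) for the longest prefix that passes."""
    best_k, best_logp = 0, 0.0
    target_cum = draft_cum = 0.0
    for k, (t, d) in enumerate(zip(target_tok_logp, draft_tok_logp), start=1):
        target_cum += t  # log p(x_{1:k}) as a sum of per-token conditional log-probs
        draft_cum += d   # log q(x_{1:k})
        if passes(target_cum, draft_cum):
            best_k, best_logp = k, target_cum
    return best_k, best_logp


def select_beam(
    beams: List[Tuple[Sequence[float], Sequence[float]]],
) -> Tuple[int, int, float]:
    """MMTAD-style choice: prefer the longest accepted prefix, breaking ties by the
    higher target joint log-likelihood. Returns (beam index, k, log-likelihood)."""
    return max(
        ((i, *longest_accepted_prefix(t, d)) for i, (t, d) in enumerate(beams)),
        key=lambda r: (r[1], r[2]),
    )


if __name__ == "__main__":
    beams = [([-1.0, -1.0, -5.0], [-1.0, -1.0, -1.0]),
             ([-0.5, -0.5], [-1.0, -1.0]),
             ([-2.0], [-2.0])]
    print(select_beam(beams))
```

In a real system, the target log-probabilities for all beams would come from a single forward pass of the large model; Figure 3(b) uses tree attention for this. The sketch also omits what happens after a prefix is rejected.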

Theorems & Definitions (17)

  • Definition 3.1
  • Theorem 3.2
  • Corollary 3.3
  • Lemma 3.4
  • Theorem 3.5
  • Theorem 3.6
  • Theorem 3.7
  • Proof
  • Theorem 3.8
  • Proof