Direct Multi-Token Decoding

Xuan Luo, Weizhi Wang, Xifeng Yan

TL;DR

This work tackles the inefficiency of token-by-token autoregressive decoding in large language models by introducing Direct Multi-Token Decoding (DMTD), which reuses the late transformer layers to generate multiple tokens in fixed cycles without adding parameters or requiring post-generation verification. DMTD is trained end-to-end with cyclical masking, so that multi-token prediction is learned within a single forward pass, and at inference it uses cyclical refilling to maintain the KV cache across cycles. Empirical results on a 36-layer Qwen3-4B model show up to a 2x speedup with minor performance loss for cycle lengths up to four tokens, with stronger gains observed for larger models and more data. The method scales predictably with training data and model size, suggesting that larger datasets and bigger architectures could yield even greater speedups without extra components or post-processing. Overall, DMTD offers a practical way to accelerate memory-bound decoding by exploiting the layer specialization of decoder-only transformers.

Abstract

Decoder-only transformers have become the standard architecture for large language models (LLMs) due to their strong performance. Recent studies suggest that, in pre-trained LLMs, early, middle, and late layers may serve distinct roles: Early layers focus on understanding the input context, middle layers handle task-specific processing, and late layers convert abstract representations into output tokens. We hypothesize that once representations have been processed by the early and middle layers, the resulting hidden states may encapsulate sufficient information to support the generation of multiple tokens using only the late layers, eliminating the need to repeatedly traverse the early and middle layers. We refer to this inference paradigm as Direct Multi-Token Decoding (DMTD). Unlike speculative decoding, our method introduces no additional parameters, auxiliary routines, or post-generation verification. Despite being trained on a limited dataset, a fine-tuned DMTD Qwen3-4B model has already demonstrated promising results, achieving up to a 2x speedup with only minor performance loss. Moreover, as shown in our scaling analysis, its performance is expected to further improve with larger training datasets.
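
As a rough illustration of where such a speedup could come from, the sketch below counts transformer-layer traversals per generated token. It assumes that the first token of each cycle traverses all layers and that the remaining tokens in the cycle traverse only the late layers. The number of late layers (`n_late`) is a hypothetical parameter that is not given in this summary, and the estimate ignores KV-cache refilling, embeddings, and the output head, so it is an idealized bound and not the measured 2x figure.

```python
def layers_per_token(n_layers: int, n_late: int, cycle_len: int) -> float:
    """Average number of layer traversals per generated token in one cycle.

    Assumption (not from the paper text above): one full pass through all
    n_layers per cycle, then (cycle_len - 1) passes through n_late layers.
    """
    if not (0 < n_late <= n_layers) or cycle_len < 1:
        raise ValueError("need 0 < n_late <= n_layers and cycle_len >= 1")
    full_pass = n_layers                     # first token of the cycle
    late_only = (cycle_len - 1) * n_late     # remaining tokens of the cycle
    return (full_pass + late_only) / cycle_len


def idealized_speedup(n_layers: int, n_late: int, cycle_len: int) -> float:
    """Ratio of vanilla layer traversals (n_layers per token) to DMTD's."""
    return n_layers / layers_per_token(n_layers, n_late, cycle_len)


if __name__ == "__main__":
    # 36 layers matches the Qwen3-4B model mentioned above;
    # n_late=8 is purely illustrative.
    for cycle_len in (1, 2, 3, 4):
        print(cycle_len, round(idealized_speedup(36, 8, cycle_len), 2))
```

With a cycle length of 1 the estimate reduces to vanilla decoding (speedup 1.0). Larger cycle lengths amortize the single full pass over more tokens, which is why the idealized gain grows with cycle length while output quality, per the abstract, is what limits it in practice.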

Paper Structure

This paper contains 17 sections, 7 equations, 5 figures, 5 tables.

Figures (5)

  • Figure 1: Vanilla next token prediction vs. Direct Multi-Token Decoding.
  • Figure 2: DMTD training pipeline with a cycle length of 3. The method requires no additional parameters and uses a single forward pass with masking to enable multi-token prediction training.
  • Figure 3: Cyclical refilling for multi-token decoding with a cycle length of 3. There are three blocks within each column, representing the early, middle, and late layers. Blocks with the same color are computed in the same forward pass. The numbers on the blocks represent the index of the forward pass.
  • Figure 4: Speedup comparison.
  • Figure 5: Scaling law of the proposed Direct Multi-Token Decoding. The x-axis represents the number of training tokens (in billions) on a logarithmic scale, while the y-axis shows the cross-entropy loss.
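
The forward-pass bookkeeping described in the Figure 3 caption can be sketched as a schedule of which layer groups run on which positions. This is one plausible reading of the caption, not the authors' implementation: it assumes the first step of each cycle is a full pass whose early and middle layers also process the positions that were decoded with late layers only in the previous cycle (the "refilling"), and that every other step runs the late layers alone. The function name `dmtd_schedule` and the dictionary keys are hypothetical.

```python
from typing import Dict, List


def dmtd_schedule(num_steps: int, cycle_len: int) -> List[Dict[str, List[int]]]:
    """For each decoding step, list the positions each layer group processes.

    Assumed behavior (hedged reading of the Figure 3 caption):
      - step % cycle_len == 0: full pass; early/middle layers run on the
        current position plus the late-only positions of the previous cycle.
      - otherwise: only the late layers run, on the current position.
    """
    if num_steps < 0 or cycle_len < 1:
        raise ValueError("need num_steps >= 0 and cycle_len >= 1")
    schedule = []
    for step in range(num_steps):
        if step % cycle_len == 0:
            # Refill the early/middle KV cache for positions skipped last cycle.
            first_unfilled = max(0, step - cycle_len + 1)
            early_middle = list(range(first_unfilled, step + 1))
        else:
            early_middle = []
        schedule.append({"early_middle": early_middle, "late": [step]})
    return schedule


if __name__ == "__main__":
    for i, entry in enumerate(dmtd_schedule(7, 3), start=1):
        print(f"pass {i}: {entry}")
```

Under these assumptions every position passes through the early and middle layers exactly once, either in its own full pass or in the refill at the start of the next cycle, so their KV cache stays complete without any extra forward passes.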