
Learning to Parallel: Accelerating Diffusion Large Language Models via Learnable Parallel Decoding

Wenrui Bao, Zhiben Chen, Dan Xu, Yuzhang Shang

TL;DR

This work tackles the sequential-decoding bottleneck of autoregressive large language models by leveraging diffusion-based parallel decoding. It introduces Learn2PD, a lightweight post-trained filter that predicts token stability to approximate an oracle that unmasks tokens only when they are correctly predicted. It also adds End-of-Text Prediction (EoTP) to terminate decoding early. The approach delivers substantial inference speedups with minimal accuracy loss. It is also orthogonal to KV caching, which enables even larger gains (up to $57.51\times$ in combination with Dual Cache). Experiments on LLaDA-8B-Instruct across GSM8K, Math, HumanEval, and MBPP show speedups ranging from roughly $3$–$4\times$ (256 tokens) to $22.58\times$ (1024 tokens), with modest reductions in accuracy when caching mechanisms are employed.
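
The oracle mentioned above can be written as a single remasking rule, which Figure 3(a) calls Extremely Greedy Parallel (EGP). The sketch below is a minimal illustration, not the authors' code. It makes the following assumptions:

- token ids are integers, with a hypothetical `MASK = -1` sentinel;
- a reference answer is known, so the rule is only usable offline, for example to study how many steps a block needs;
- the per-step predictions are supplied by the caller rather than produced by a real dLLM.

```python
from typing import List, Optional, Sequence

MASK = -1  # hypothetical sentinel id for a masked position


def egp_step(current: Sequence[int], predicted: Sequence[int],
             reference: Sequence[int]) -> List[int]:
    """One EGP step: unmask every masked position whose prediction equals
    the reference token, and remask every position whose prediction is wrong."""
    out: List[int] = []
    for cur, pred, ref in zip(current, predicted, reference):
        if cur != MASK:
            out.append(cur)   # already decoded in an earlier step: keep it
        elif pred == ref:
            out.append(pred)  # correct prediction: unmask now
        else:
            out.append(MASK)  # wrong prediction: remask
    return out


def egp_block_steps(reference: Sequence[int],
                    step_predictions: Sequence[Sequence[int]]) -> Optional[int]:
    """Number of denoising steps EGP needs to fully unmask one block,
    or None if the supplied predictions never complete it."""
    seq: List[int] = [MASK] * len(reference)
    for step, preds in enumerate(step_predictions, start=1):
        seq = egp_step(seq, preds, reference)
        if MASK not in seq:
            return step
    return None
```

Consider a block with reference `[5, 8, 7, 3]` and per-step predictions `[5, 9, 7, 2]`, `[5, 8, 7, 1]`, and `[0, 0, 0, 3]`. EGP fixes two tokens at step 1, one at step 2, and the last at step 3, so `egp_block_steps` returns 3. Counting such steps per block yields the kind of statistic shown in Figure 4.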

Abstract

Autoregressive decoding in large language models (LLMs) requires $\mathcal{O}(n)$ sequential steps to generate $n$ tokens, fundamentally limiting inference throughput. Recent diffusion-based LLMs (dLLMs) enable parallel token generation through iterative denoising. However, current parallel decoding strategies rely on fixed, input-agnostic heuristics (e.g., confidence thresholds) that fail to adapt to input-specific characteristics, resulting in suboptimal speed-quality trade-offs across diverse NLP tasks. In this work, we explore a more flexible and dynamic approach to parallel decoding. We propose Learning to Parallel Decode (Learn2PD), a framework that trains a lightweight, adaptive filter model to predict, for each token position, whether the current prediction matches the final output. This learned filter approximates an oracle parallel decoding strategy that unmasks tokens only when they are correctly predicted. Importantly, the filter model is learned post-training and requires only minute-level GPU time to optimize. Additionally, we introduce End-of-Text Prediction (EoTP) to detect when decoding has reached the end of the sequence, avoiding redundant decoding of padding tokens. Experiments on the LLaDA benchmark demonstrate that our method achieves up to a 22.58$\times$ speedup without any performance drop, and up to 57.51$\times$ when combined with KV-Cache.
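
A learned filter replaces a fixed confidence threshold with a decision fitted to data. The sketch below is a deliberately tiny stand-in, not the paper's filter. It assumes a per-token logistic model on the confidence score alone, trained against labels that mark whether the current prediction already equals the final output. The real filter $f_\theta$ may use a different architecture and richer inputs. The names `make_labels`, `train_filter`, and `should_unmask` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def make_labels(preds: Sequence[int], final: Sequence[int]) -> List[int]:
    """Label 1 where the current prediction already equals the final output."""
    return [int(p == f) for p, f in zip(preds, final)]


def train_filter(conf: Sequence[float], labels: Sequence[int],
                 lr: float = 0.5, epochs: int = 5000) -> Tuple[float, float]:
    """Fit p(correct | confidence) = sigmoid(w * conf + b) by full-batch
    gradient descent on the mean binary cross-entropy."""
    w, b = 0.0, 0.0
    n = len(conf)
    for _ in range(epochs):
        gw = gb = 0.0
        for c, y in zip(conf, labels):
            err = sigmoid(w * c + b) - y  # dBCE/dlogit for one token
            gw += err * c
            gb += err
        w -= lr * gw / n
        b -= lr * gb / n
    return w, b


def should_unmask(conf: Sequence[float], w: float, b: float,
                  tau: float = 0.5) -> List[bool]:
    """Unmask positions the filter scores as already correct; remask the rest."""
    return [sigmoid(w * c + b) > tau for c in conf]
```

As an example, take confidences `[0.95, 0.90, 0.85, 0.40, 0.30, 0.20]` where only the first three predictions match the final tokens. The fitted filter then unmasks a new position with confidence 0.92 and keeps one with confidence 0.25 masked. With a single confidence feature, the learned rule still collapses to one threshold. Adapting to the input would require the filter to see more context, for instance the confidences of a whole block; that extension is also an assumption here.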

Paper Structure

This paper contains 30 sections, 6 equations, 7 figures, 5 tables, 3 algorithms.

Figures (7)

  • Figure 1: Effectiveness of our proposed approaches. We report throughput and accuracy on GSM8K (5-shot, generation length 1024) for LLaDA and our proposed methods under four settings: (1) vanilla decoding, (2) the Learn2PD policy, (3) Learn2PD with the EoTP mechanism, and (4) Learn2PD and EoTP combined with KV Cache. Learn2PD and EoTP together yield a $22.58\times$ speedup over the vanilla baseline while preserving the original accuracy. Adding KV Cache further raises throughput to 16.37 tokens/sec (a $57.51\times$ speedup), with only a minimal loss in accuracy.
  • Figure 2: Unnecessary, repetitive decoding steps on two datasets, GSM8K and HumanEval. (a) Distributions of gaps. The two histograms show, for each token, the gap in steps between the step at which it is actually decoded and the step at which it is first predicted correctly. (b) Sample gaps. The red line marks the first correct-prediction step, and the blue line marks the actual decoding step.
  • Figure 3: Conceptual overview of the pipeline and method. (a) Extremely Greedy Parallel (EGP). This strategy compares the predicted tokens with the reference answer and remasks only the tokens that do not match. (b) Learning to Parallel Decode (Learn2PD). During inference, the model generates a prediction and a confidence for each token. These confidences are fed into a filter model $f_\theta$, which determines which tokens need to be remasked, and this decision guides the subsequent remasking step.
  • Figure 4: Distribution of decoding steps per block under the Extremely Greedy Parallel (EGP) strategy. The histograms show the number of decoding steps performed in each block when using this strategy with LLaDA-8B-Instruct on GSM8K, over 100 samples.
  • Figure 5: Schematic of the End-of-Text Prediction policy. During inference, once an [EoT] token is detected in a decoded block, all subsequent tokens are discarded.
  • ...and 2 more figures
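
The numbers in Figure 1 allow a quick consistency check. Assume both speedups are measured against the same vanilla throughput $T_{\text{vanilla}}$ in the same setting (GSM8K, 5-shot, generation length 1024). The implied throughputs then follow directly. They are back-of-the-envelope estimates, not values reported in the caption.

```latex
\begin{aligned}
T_{\text{vanilla}} &\approx \frac{16.37\ \text{tokens/s}}{57.51} \approx 0.2846\ \text{tokens/s},\\
T_{\text{Learn2PD+EoTP}} &\approx 22.58 \times 0.2846\ \text{tokens/s} \approx 6.43\ \text{tokens/s},\\
\frac{57.51}{22.58} &\approx 2.55 .
\end{aligned}
```

Under that assumption, adding KV Cache on top of Learn2PD and EoTP contributes roughly a further $2.5\times$.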
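
The gap statistic in Figure 2 can be computed per token from its prediction history. The sketch below rests on hypothetical assumptions about the data. Each record stores three things:

- the token predicted at every step up to its decoding step (1-based);
- the final token;
- the step at which that token was actually decoded.

The gap is the decoding step minus the first step whose prediction already equals the final token.

```python
from collections import Counter
from typing import Dict, Optional, Sequence, Tuple

Record = Tuple[Sequence[int], int, int]  # (per-step predictions, final token, decode step)


def first_correct_step(history: Sequence[int], final_token: int) -> Optional[int]:
    """1-based step at which this position first predicted its final token."""
    for step, tok in enumerate(history, start=1):
        if tok == final_token:
            return step
    return None


def step_gap(history: Sequence[int], final_token: int, decode_step: int) -> int:
    """Redundant steps: actual decoding step minus first correct-prediction step."""
    first = first_correct_step(history, final_token)
    if first is None or first > decode_step:
        raise ValueError("token was not predicted correctly before it was decoded")
    return decode_step - first


def gap_histogram(records: Sequence[Record]) -> Dict[int, int]:
    """Count how many tokens have each gap value (the data behind a gap histogram)."""
    return dict(Counter(step_gap(h, f, d) for h, f, d in records))
```

Take `records = [([7, 3, 3, 3], 3, 4), ([5, 5], 5, 2), ([2, 9, 9], 9, 3)]`. Here `gap_histogram(records)` returns `{2: 1, 1: 2}`: one token waited two redundant steps and two tokens waited one.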
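
The EoTP rule in Figure 5 amounts to an early exit in block-wise decoding. This is a minimal sketch with the following assumptions:

- `EOT = 0` is a hypothetical token id;
- `decode_block` is a caller-supplied stand-in for the dLLM's denoising of one block, given the decoded prefix;
- the [EoT] token itself is kept and everything after it is dropped, which is a guess about the exact truncation point.

```python
from typing import Callable, List

EOT = 0  # hypothetical token id for [EoT]


def generate_with_eotp(num_blocks: int,
                       decode_block: Callable[[int, List[int]], List[int]]) -> List[int]:
    """Block-wise decoding with End-of-Text Prediction: as soon as a decoded
    block contains [EoT], keep tokens up to and including it, discard the rest,
    and skip all remaining blocks."""
    out: List[int] = []
    for idx in range(num_blocks):
        block = decode_block(idx, list(out))  # fully decoded tokens of block idx
        if EOT in block:
            out.extend(block[: block.index(EOT) + 1])
            break  # later blocks would only hold padding, so never decode them
        out.extend(block)
    return out
```

Suppose the three blocks would decode to `[11, 12, 13]`, `[14, EOT, 99]`, and `[99, 99, 99]`. The function returns `[11, 12, 13, 14, EOT]` and never calls `decode_block` for the third block, which is where the savings on padding come from.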