Learning to Parallel: Accelerating Diffusion Large Language Models via Learnable Parallel Decoding
Wenrui Bao, Zhiben Chen, Dan Xu, Yuzhang Shang
TL;DR
This work targets the sequential bottleneck of autoregressive decoding in large language models by building on diffusion-based parallel decoding. It introduces Learn2PD, a lightweight post-trained filter that predicts whether each token's current prediction is already stable, approximating an oracle that unmasks a token only once it is correctly predicted. It also introduces End-of-Text Prediction (EoTP), which terminates decoding early. The approach delivers substantial inference speedups with minimal accuracy loss and is orthogonal to KV caching, so the two combine for even larger gains (up to $57.51\times$ with Dual Cache). Experiments with LLaDA-8B-Instruct on GSM8K, MATH, HumanEval, and MBPP show speedups ranging from roughly $3$–$4\times$ at a generation length of 256 tokens to $22.58\times$ at 1024 tokens.
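To make the oracle concrete, the sketch below shows one plausible way to derive per-position training targets from a recorded decoding trajectory. A position is labeled 1 at a step when its current prediction already equals the final output. The first step at which that holds is when an oracle could safely unmask it. The function names and the trajectory format (a list of per-step token predictions) are hypothetical illustrations, not the authors' code. In a real dLLM the trajectory also depends on which tokens were unmasked along the way.

```python
from typing import List

def stability_labels(step_preds: List[List[int]], final_tokens: List[int]) -> List[List[int]]:
    """Per step and position: 1 if the current prediction already equals the final token."""
    labels = []
    for preds in step_preds:
        if len(preds) != len(final_tokens):
            raise ValueError("every step must predict every position")
        labels.append([int(p == f) for p, f in zip(preds, final_tokens)])
    return labels

def first_correct_step(step_preds: List[List[int]], final_tokens: List[int]) -> List[int]:
    """Earliest step at which each position is correct, i.e. when an oracle could unmask it (-1 if never)."""
    first = [-1] * len(final_tokens)
    for t, preds in enumerate(step_preds):
        for i, (p, f) in enumerate(zip(preds, final_tokens)):
            if first[i] == -1 and p == f:
                first[i] = t
    return first

# Toy trajectory (hypothetical): three denoising steps over a three-token answer.
steps = [[5, 9, 2], [5, 7, 2], [5, 7, 3]]
final = [5, 7, 3]
print(stability_labels(steps, final))    # [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
print(first_correct_step(steps, final))  # [0, 1, 2]
```

Labels of this kind are what a small filter would be trained to predict from information available at decode time, without access to the final output.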
Abstract
Autoregressive decoding in large language models (LLMs) requires $\mathcal{O}(n)$ sequential steps for $n$ tokens, fundamentally limiting inference throughput. Recent diffusion-based LLMs (dLLMs) enable parallel token generation through iterative denoising. However, current parallel decoding strategies rely on fixed, input-agnostic heuristics (e.g., confidence thresholds), which fail to adapt to input-specific characteristics and yield suboptimal speed-quality trade-offs across diverse NLP tasks. In this work, we explore a more flexible and dynamic approach to parallel decoding. We propose Learning to Parallel Decode (Learn2PD), a framework that trains a lightweight, adaptive filter model to predict, for each token position, whether the current prediction matches the final output. This learned filter approximates an oracle parallel decoding strategy that unmasks tokens only when they are correctly predicted. Importantly, the filter is learned post hoc, and training it takes only minutes of GPU time. Additionally, we introduce End-of-Text Prediction (EoTP) to detect when generation has finished, avoiding redundant decoding of the padding tokens that fill the rest of the sequence. Experiments with LLaDA-8B-Instruct on standard benchmarks show that our method achieves up to a 22.58$\times$ speedup without any performance drop, and up to 57.51$\times$ when combined with KV-Cache.
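The decoding loop described above can be sketched as follows, under stated assumptions. `predict` stands in for one dLLM denoising pass, and `accept` stands in for the learned filter (in the paper a small trained model; here any callable). Both names, the `MASK`/`EOT` ids, and the toy model are hypothetical. The early-exit rule used here stops once an end-of-text token is committed and everything before it is filled. It is a simplified stand-in for EoTP, not the paper's exact criterion.

```python
from typing import Callable, List, Tuple

MASK = -1  # hypothetical id for a still-masked position
EOT = 0    # hypothetical end-of-text / padding id

def parallel_decode(
    length: int,
    predict: Callable[[List[int]], Tuple[List[int], List[float]]],
    accept: Callable[[int, int, float], bool],
    max_steps: int = 100,
) -> Tuple[List[int], int]:
    """Sketch of filter-driven parallel decoding with a simplified end-of-text early exit.

    Returns the decoded sequence and the number of denoising passes used.
    """
    seq = [MASK] * length
    if length == 0:
        return seq, 0
    for step in range(1, max_steps + 1):
        tokens, conf = predict(seq)  # one denoising pass over the whole sequence
        masked = [i for i, t in enumerate(seq) if t == MASK]
        # Stand-in for the learned filter: unmask every position it judges stable.
        chosen = [i for i in masked if accept(i, tokens[i], conf[i])]
        if not chosen:
            # Guarantee progress: commit the most confident masked position.
            chosen = [max(masked, key=lambda i: conf[i])]
        for i in chosen:
            seq[i] = tokens[i]
        # Simplified EoT rule (assumption): once an EoT is committed and every position
        # before it is filled, the remainder is padding, so fill it and stop early.
        if EOT in seq:
            end = seq.index(EOT)
            if MASK not in seq[:end]:
                seq[end:] = [EOT] * (length - end)
                return seq, step
        if MASK not in seq:
            return seq, step
    return seq, max_steps

# Toy usage: a fake "model" that always predicts the same answer with confidence
# decaying along the sequence. The fixed threshold below is only a placeholder
# for a trained filter.
target = [7, 8, 9, EOT, EOT, EOT, EOT, EOT]

def toy_predict(seq: List[int]) -> Tuple[List[int], List[float]]:
    return list(target), [1.0 - 0.1 * i for i in range(len(target))]

result = parallel_decode(len(target), toy_predict, lambda i, tok, c: c > 0.75)
print(result)  # ([7, 8, 9, 0, 0, 0, 0, 0], 2)
```

With the toy predictor, the placeholder filter commits the first three tokens in one pass and the fallback commits the end-of-text token in the second. The loop therefore ends after 2 passes instead of the 6 it would take to fill every padding position one at a time.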
