Dependency-Guided Parallel Decoding in Discrete Diffusion Language Models

Liran Ringel, Ameen Ali, Yaniv Romano

Abstract

Discrete diffusion language models (dLLMs) accelerate text generation by unmasking multiple tokens in parallel. However, parallel decoding introduces a distributional mismatch: it approximates the joint conditional using a fully factorized product of per-token marginals, which degrades output quality when selected tokens are strongly dependent. We propose DEMASK (DEpendency-guided unMASKing), a lightweight dependency predictor that attaches to the final hidden states of a dLLM. In a single forward pass, it estimates pairwise conditional influences between masked positions. Using these predictions, a greedy selection algorithm identifies positions with bounded cumulative dependency for simultaneous unmasking. Under a sub-additivity assumption, we prove this bounds the total variation distance between our parallel sampling and the model's joint. Empirically, DEMASK achieves 1.7-2.2$\times$ speedup on Dream-7B while matching or improving accuracy compared to confidence-based and KL-based baselines.
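To make the selection step described above concrete, here is a minimal sketch of one way such a greedy, dependency-bounded rule could be implemented. The candidate ordering (by model confidence), the definition of a position's marginal cost as its summed pairwise dependencies against already-selected positions, and all names are illustrative assumptions, not the paper's exact algorithm.

```python
import numpy as np


def greedy_select(dep: np.ndarray, order: list[int], tau: float) -> list[int]:
    """Greedily pick masked positions to unmask in parallel (illustrative sketch).

    dep:   estimated pairwise dependency matrix D_hat over masked positions,
           with dep[i, j] >= 0.
    order: candidate positions, e.g. sorted by decreasing model confidence.
    tau:   budget on the cumulative dependency within the selected set.
    """
    selected: list[int] = []
    budget_used = 0.0
    for pos in order:
        # Added dependency if `pos` joins the current set (both directions).
        cost = sum(dep[pos, s] + dep[s, pos] for s in selected)
        if budget_used + cost <= tau:
            selected.append(pos)
            budget_used += cost
    return selected or order[:1]  # always unmask at least one position
```

In this reading, a smaller budget tau forces fewer, weakly coupled positions per step (slower but closer to the joint), while a larger tau unmasks more positions in parallel.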

Paper Structure

This paper contains 37 sections, 1 theorem, 20 equations, 4 figures, 2 tables, 2 algorithms.

Key Result

Theorem 4.2

Suppose that the sub-additivity assumption holds and that the greedy selection algorithm is run with the true dependency matrix $\mathbf{D}$ and error threshold $\tau$. Then the selected subset $S$ returned by the algorithm satisfies $\textup{TV}\big(P_\theta(Y_S \mid X, Y_U),\, Q_\theta(Y_S \mid X, Y_U)\big) \le \tau$.
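Read together with the abstract, a plausible sketch of how this guarantee composes is the following two-step chain; the exact form of the per-pair dependency terms $D_{ij}$ and of the cumulative constraint enforced by the greedy selection is an assumption here, inferred from the high-level description rather than the full derivation:

$$
\textup{TV}\big(P_\theta(Y_S \mid X, Y_U),\, Q_\theta(Y_S \mid X, Y_U)\big)
\;\le\; \sum_{\substack{i, j \in S \\ i \neq j}} D_{ij}
\;\le\; \tau,
$$

where the first inequality would follow from the sub-additivity assumption applied to the selected set $S$, and the second because the greedy selection only admits positions while the cumulative dependency stays within the budget $\tau$.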

Figures (4)

  • Figure 1: Accuracy vs. mean diffusion steps (forward passes) on GSM8K for DEMASK and KLASS. Each point represents a hyperparameter configuration; darker points indicate higher confidence thresholds. The Pareto frontier (dashed) shows DEMASK dominates across the efficiency-accuracy trade-off. Dream (1 TPF) denotes 1 token per forward pass with entropy selection.
  • Figure 2: Overview of DEMASK. (A) A lightweight dependency predictor attaches to the dLLM backbone and estimates pairwise dependencies $\hat{\mathbf{D}}$ from hidden states in a single forward pass. (B) Greedy subset selection identifies positions with bounded cumulative dependency for parallel unmasking. (C) The iterative decoding cycle: each step performs a forward pass, selects positions, and samples them in parallel until all tokens are unmasked.
  • Figure 3: Dependency predictor architecture. Hidden states $\mathbf{H}$ from the frozen backbone are projected via learned $\mathbf{W}_Q, \mathbf{W}_K$, then combined via scaled dot-product and sigmoid to predict the pairwise dependency matrix $\hat{\mathbf{D}}$ (a minimal code sketch of this head follows the figure list).
  • Figure 4: Empirical CDF of the slack $\mathrm{RHS}_i - \mathrm{LHS}_i$ stratified by subset size $|S|$, evaluated on Tulu 3 SFT Mixture with Dream-7B. The $|S|=1$ curve (vertical line at zero) is trivially satisfied. For $|S| \geq 2$, positive slack indicates the sub-additivity bound holds.
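As a companion to the Figure 3 caption, here is a minimal sketch of how such a dependency-predictor head could look. Only the overall structure (learned $\mathbf{W}_Q$, $\mathbf{W}_K$ projections, scaled dot-product, sigmoid) follows the caption; the module name, projection size, and other details are assumptions for illustration.

```python
import torch
import torch.nn as nn


class DependencyHead(nn.Module):
    """Hypothetical dependency-predictor head (names and sizes are illustrative).

    Maps final hidden states H of shape (batch, seq_len, d_model) from a frozen
    dLLM backbone to a pairwise dependency matrix D_hat in [0, 1]^(seq_len x seq_len),
    following the structure in the Figure 3 caption: learned W_Q / W_K projections,
    scaled dot-product, then a sigmoid.
    """

    def __init__(self, d_model: int, d_proj: int = 128):
        super().__init__()
        self.w_q = nn.Linear(d_model, d_proj, bias=False)  # W_Q
        self.w_k = nn.Linear(d_model, d_proj, bias=False)  # W_K
        self.scale = d_proj ** -0.5

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q = self.w_q(hidden_states)                         # (B, L, d_proj)
        k = self.w_k(hidden_states)                         # (B, L, d_proj)
        scores = torch.einsum("bld,bmd->blm", q, k) * self.scale
        return torch.sigmoid(scores)                        # D_hat: (B, L, L)
```

At decoding time, presumably only the rows and columns corresponding to currently masked positions would be consulted; that restriction is omitted from the sketch.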

Theorems & Definitions (2)

  • Theorem 4.2: Correctness of the greedy selection algorithm
  • Proof of Theorem 4.2