Where-to-Unmask: Ground-Truth-Guided Unmasking Order Learning for Masked Diffusion Language Models

Hikaru Asano, Tadashi Kozuno, Kuniaki Saito, Yukino Baba

TL;DR

Masked Diffusion Language Models (MDLMs) generate text by iteratively unmasking tokens, which couples two decisions: where-to-unmask and what-to-unmask. We introduce Gt-Margin, a ground-truth-driven score that ranks masked positions by how decisively the model favors the ground-truth token, yielding an oracle easy-to-hard unmasking order, and show that this ordering substantially improves performance on logical reasoning benchmarks. Building on this, we train a supervised unmasking planner via learning-to-rank to imitate the oracle ordering and integrate it as a plug-and-play component during MDLM sampling, without altering the token prediction model. Empirically, guidance in the early unmasking steps is particularly impactful: the planner provides additional gains on GSM8K and MATH, and partial-plan decoding concentrates planning on the high-uncertainty regime. Overall, the work delivers a practical, data-driven method for optimizing where-to-unmask in MDLMs, enabling better reasoning while preserving decoding efficiency.

Abstract

Masked Diffusion Language Models (MDLMs) generate text by iteratively filling masked tokens, requiring two coupled decisions at each step: which positions to unmask (where-to-unmask) and which tokens to place (what-to-unmask). While standard MDLM training directly optimizes token prediction (what-to-unmask), inference-time unmasking orders (where-to-unmask) are typically determined by heuristic confidence measures or trained through reinforcement learning with costly on-policy rollouts. To address this, we introduce Gt-Margin, a position-wise score derived from ground-truth tokens, defined as the probability margin between the correct token and its strongest alternative. Gt-Margin yields an oracle unmasking order that prioritizes easier positions first under each partially masked state. We demonstrate that leveraging this oracle unmasking order significantly enhances final generation quality, particularly on logical reasoning benchmarks. Building on this insight, we train a supervised unmasking planner via learning-to-rank to imitate the oracle ordering from masked contexts. The resulting planner integrates into standard MDLM sampling to select where-to-unmask, improving reasoning accuracy without modifying the token prediction model.
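
The Gt-Margin definition above can be illustrated with a minimal sketch for a single partially masked state. The function names, the toy three-token vocabulary, and the probability values are illustrative assumptions rather than the paper's implementation. The heuristic Margin is assumed here to be the usual top-1 minus top-2 probability gap, which the text above does not spell out.

```python
from typing import Dict, List, Sequence


def gt_margin(probs: Sequence[float], gt_token: int) -> float:
    """Probability of the ground-truth token minus its strongest alternative."""
    best_other = max(p for tok, p in enumerate(probs) if tok != gt_token)
    return probs[gt_token] - best_other


def margin(probs: Sequence[float]) -> float:
    """Assumed heuristic Margin: top-1 minus top-2 probability (no ground truth)."""
    top = sorted(probs, reverse=True)
    return top[0] - top[1]


def oracle_ranking(
    probs_by_pos: Dict[int, Sequence[float]], gt_tokens: Dict[int, int]
) -> List[int]:
    """Rank masked positions from easiest (largest Gt-Margin) to hardest."""
    return sorted(
        probs_by_pos,
        key=lambda pos: gt_margin(probs_by_pos[pos], gt_tokens[pos]),
        reverse=True,
    )


# Toy state: three masked positions, hypothetical model probabilities.
probs_by_pos = {
    2: [0.7, 0.2, 0.1],    # model is confident and correct
    5: [0.1, 0.6, 0.3],    # model is confident but wrong (ground truth is token 2)
    7: [0.4, 0.35, 0.25],  # model is correct but uncertain
}
gt_tokens = {2: 0, 5: 2, 7: 0}

ranking = oracle_ranking(probs_by_pos, gt_tokens)
```

In this toy state, position 5 has a negative Gt-Margin because the model currently prefers a wrong token, so the oracle defers it to last, whereas the assumed Margin heuristic would still treat it as fairly confident. Since the oracle order is defined under each partially masked state, an actual decoder would recompute these scores after every unmasking step rather than ranking once.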

Paper Structure (56 sections, 19 equations, 6 figures, 6 tables, 2 algorithms)

Figures (6)

  • Figure 1: MDLMs learn what-to-unmask but leave where-to-unmask implicit, and heuristic scores (e.g., Margin) can yield incorrect outputs. We define Gt-Margin using ground-truth tokens and train a planner to imitate the oracle ordering, improving reasoning accuracy without modifying the token model.
  • Figure 2: When is Gt-Margin important? We start from Gt-Margin ordering and replace it only within a single 10% unmasking-step range (0–10%, 10–20%, …, 90–100%) by either Margin or Random, keeping Gt-Margin elsewhere (i.e., 90% of steps still use Gt-Margin). Across both LLaDA and Dream, perturbations in early steps lead to markedly worse final accuracy on most datasets, while late-step perturbations are less harmful, highlighting the outsized impact of early ordering decisions.
  • Figure 3: How long do we need oracle-like ordering? We decode with Gt-Margin for the first $n\%$ of steps and switch to Margin for the remaining $(100-n)\%$. Using Gt-Margin for approximately the first half is sufficient to recover most gains, indicating that oracle guidance is primarily beneficial early in decoding.
  • Figure 4: Empirical unmasking-order heatmaps on GSM8K. Both axes are normalized (x: token position, y: unmasking step), and color indicates how frequently a relative position is selected at a relative step. Margin concentrates tightly along the diagonal (near left-to-right behavior), while Gt-Margin retains a diagonal trend but places more mass off-diagonal, indicating more adaptive jumps to contextually easy positions.
  • Figure 5: Overview of learning an oracle-guided unmasking ordering planner. (a) We run controlled decoding from a fully masked state and select the next position using Gt-Margin, yielding a generated completion $\hat{\mathbf{x}}_0$ and an oracle unmasking order $\mathbf{r}$. (b) We first sample a diffusion time $t$ to construct a partially masked sequence $\mathbf{x}_t$; we then form a reconstructed sequence $\mathbf{x}_0'$ by filling masked positions with the token model's argmax predictions and build the planner input as prompt + $\mathbf{x}_0'$ + $\mathbf{x}_t$. (c) The planner outputs priority scores $\mathbf{s}$ over masked positions via an MLP head and is trained with listwise ranking supervision to match $\mathbf{r}$.
  • ...and 1 more figure
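
The Figure 5 caption says the planner is trained with listwise ranking supervision to match the oracle order, but it does not name the loss. As one possible instance, the sketch below uses a ListMLE-style (Plackett-Luce) negative log-likelihood. This choice, the function name `listmle_loss`, and the example scores are assumptions for illustration and should not be read as the paper's actual objective.

```python
import math
from typing import Sequence


def listmle_loss(scores: Sequence[float], oracle_order: Sequence[int]) -> float:
    """Negative log-likelihood of `oracle_order` under a Plackett-Luce model.

    scores[i] is the planner's priority score for masked slot i (hypothetical).
    oracle_order lists slot indices from first-unmasked to last-unmasked.
    """
    loss = 0.0
    for k, chosen in enumerate(oracle_order):
        # Slots not yet unmasked at step k compete for selection.
        remaining = [scores[i] for i in oracle_order[k:]]
        # Numerically stable log-sum-exp over the remaining scores.
        m = max(remaining)
        log_z = m + math.log(sum(math.exp(s - m) for s in remaining))
        loss += log_z - scores[chosen]
    return loss


# Hypothetical planner scores for three masked slots.
scores = [2.0, 0.0, 1.0]
agreeing_order = [0, 2, 1]   # matches descending score order
reversed_order = [1, 2, 0]   # opposite of the score order

loss_agree = listmle_loss(scores, agreeing_order)
loss_reversed = listmle_loss(scores, reversed_order)
```

The loss is lower when the scores sort the slots in the same order as the oracle, so minimizing it pushes the planner to score earlier-unmasked positions higher. With all-equal scores over three slots, it reduces to $\log 3! = \log 6$, the cost of a uniformly random ordering.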