Where-to-Unmask: Ground-Truth-Guided Unmasking Order Learning for Masked Diffusion Language Models
Hikaru Asano, Tadashi Kozuno, Kuniaki Saito, Yukino Baba
TL;DR
Masked Diffusion Language Models (MDLMs) generate text by iteratively unmasking tokens, which couples two decisions: where-to-unmask and what-to-unmask. We introduce Gt-Margin, a ground-truth-driven score that ranks masked positions by how decisively the model prefers the ground-truth token. This yields an oracle easy-to-hard unmasking order, and we show that following this order substantially improves reasoning performance on logical benchmarks. Building on this, we train a supervised unmasking planner via learning-to-rank to imitate the oracle ordering and integrate it as a plug-and-play component during MDLM sampling, without altering the token prediction model. Empirically, guidance during the early unmasking steps is particularly impactful: the planner provides additional gains on GSM8K and MATH, and partial-plan decoding concentrates planning on the high-uncertainty regime. Overall, the work delivers a practical, data-driven method for optimizing where-to-unmask in MDLMs, improving reasoning while preserving decoding efficiency.
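To make the scoring concrete, here is a minimal sketch of Gt-Margin and the oracle selection step it induces. It assumes the margin is taken in probability space after a softmax over one position's vocabulary logits (ground-truth probability minus the largest probability of any other token), and that each step unmasks the `k` masked positions with the largest margin. The helper names (`gt_margin`, `oracle_unmask_positions`) and the per-step budget `k` are hypothetical illustrations, not the authors' code.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one position's vocabulary logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gt_margin(logits: Sequence[float], gt_id: int) -> float:
    """Probability of the ground-truth token minus the strongest alternative.

    Positive means the model already prefers the correct token; a larger value
    means a more decisive (easier) position.
    """
    probs = softmax(logits)
    best_other = max(p for v, p in enumerate(probs) if v != gt_id)
    return probs[gt_id] - best_other


def oracle_unmask_positions(
    logits: Sequence[Sequence[float]],  # [seq_len][vocab] outputs for the current partially masked state
    gt_ids: Sequence[int],              # ground-truth token id at each position
    masked: Sequence[bool],             # True where the position is still masked
    k: int,                             # hypothetical number of positions to unmask this step
) -> List[int]:
    """Return up to k masked positions ordered easy-to-hard by Gt-Margin."""
    scored = [(gt_margin(logits[i], gt_ids[i]), i) for i in range(len(gt_ids)) if masked[i]]
    scored.sort(key=lambda t: (-t[0], t[1]))  # descending margin; ties broken by position
    return [i for _, i in scored[:k]]
```

For example, with `logits = [[2.0, 0.0, 0.0], [0.5, 0.4, 0.0], [0.0, 3.0, 0.0], [1.0, 1.0, 1.0]]`, `gt_ids = [0, 0, 1, 2]`, and the last position already unmasked, the margins are roughly 0.68, 0.04, and 0.86, so a budget of `k=2` selects positions `[2, 0]`. Because the scores are recomputed from the model's outputs at every partially masked state, the resulting order adapts as context is filled in.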
Abstract
Masked Diffusion Language Models (MDLMs) generate text by iteratively filling masked tokens, requiring two coupled decisions at each step: which positions to unmask (where-to-unmask) and which tokens to place (what-to-unmask). While standard MDLM training directly optimizes token prediction (what-to-unmask), inference-time unmasking orders (where-to-unmask) are typically determined by heuristic confidence measures or learned through reinforcement learning with costly on-policy rollouts. To address this gap, we introduce Gt-Margin, a position-wise score derived from ground-truth tokens, defined as the probability margin between the correct token and its strongest alternative. Gt-Margin yields an oracle unmasking order that prioritizes easier positions first under each partially masked state. We demonstrate that following this oracle unmasking order significantly improves final generation quality, particularly on logical reasoning benchmarks. Building on this insight, we train a supervised unmasking planner via learning-to-rank to imitate the oracle ordering from masked contexts. The resulting planner integrates into standard MDLM sampling to select where-to-unmask, improving reasoning accuracy without modifying the token prediction model.
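The learning-to-rank step can be sketched as follows, assuming a RankNet-style pairwise logistic loss (the abstract does not name the specific ranking objective): for every pair of masked positions where the oracle Gt-Margin of `i` exceeds that of `j`, the planner is penalized by `log(1 + exp(-(s_i - s_j)))`. In this sketch the planner scores `s` are free parameters for a single masked state, standing in for the outputs of a planner network conditioned on the masked context; the function names and the top-`k` selection rule are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def _softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def _ordered_pairs(targets: Sequence[float], masked: Sequence[bool]) -> List[Tuple[int, int]]:
    """All masked pairs (i, j) where the oracle says i is easier than j."""
    idx = [i for i, m in enumerate(masked) if m]
    return [(i, j) for i in idx for j in idx if targets[i] > targets[j]]


def pairwise_rank_loss(scores: Sequence[float], targets: Sequence[float], masked: Sequence[bool]) -> float:
    """Mean of log(1 + exp(-(s_i - s_j))) over oracle-ordered pairs."""
    pairs = _ordered_pairs(targets, masked)
    if not pairs:
        return 0.0
    return sum(_softplus(-(scores[i] - scores[j])) for i, j in pairs) / len(pairs)


def pairwise_rank_grad(scores: Sequence[float], targets: Sequence[float], masked: Sequence[bool]) -> List[float]:
    """Gradient of pairwise_rank_loss with respect to every planner score."""
    pairs = _ordered_pairs(targets, masked)
    grad = [0.0] * len(scores)
    for i, j in pairs:
        # With d = s_i - s_j: d/dd log(1 + exp(-d)) = -sigmoid(-d).
        g = -_sigmoid(-(scores[i] - scores[j]))
        grad[i] += g
        grad[j] -= g
    n = max(len(pairs), 1)
    return [x / n for x in grad]


def planner_unmask_positions(scores: Sequence[float], masked: Sequence[bool], k: int) -> List[int]:
    """Plug-in selection step: unmask the k masked positions the planner ranks highest."""
    ranked = sorted((i for i, m in enumerate(masked) if m), key=lambda i: (-scores[i], i))
    return ranked[:k]
```

As a usage example, starting from all-zero scores with oracle margins `[0.68, 0.04, 0.86, 0.0]` (last position unmasked) and taking a few hundred gradient-descent steps `s <- s - lr * grad`, the loss falls from `log 2` and the planner's top-2 selection becomes `[2, 0]`, matching the oracle. Because only the selection step changes, the token prediction model is left untouched.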
