Improving Discrete Diffusion Unmasking Policies Beyond Explicit Reference Policies

Chunsan Hong, Seonho An, Min-Soo Kim, Jong Chul Ye

TL;DR

This work addresses the sensitivity of discrete diffusion models to unmasking order in language generation. It reframes denoising as a KL-regularized MDP with an explicit reference policy and learns a parametric unmasking policy via GRPO, providing theoretical guarantees of policy improvement and tighter alignment to the data distribution than heuristic references. The authors derive tractable surrogates like $\mathcal{L}_{\rm output}$, $\mathcal{L}_{\rm token}$, and $\mathcal{L}_{\rm KL}$, enabling memory-efficient training of a compact policy model that augments a frozen MDM. Empirically, the learned policy consistently outperforms strong baselines across Sudoku, Zebra, GSM8K, and Math500 benchmarks, with notable gains on logic puzzles and competitive results on reasoning tasks, illustrating the practical impact of optimized unmasking strategies for discrete diffusion models.
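The TL;DR mentions training the unmasking policy with GRPO under KL regularization toward a reference policy. The exact surrogates $\mathcal{L}_{\rm output}$, $\mathcal{L}_{\rm token}$, and $\mathcal{L}_{\rm KL}$ are not spelled out in this summary, so the following is only a generic sketch of a GRPO-style objective for an unmasking policy. It assumes binary rewards, a plain (unclipped) policy-gradient term, and an exact per-step KL between position distributions. All function and argument names are hypothetical.

```python
import math
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-6):
    """Standardize rewards within one group of rollouts for the same prompt."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def categorical_kl(p, q):
    """KL(p || q) between two distributions over unmasking positions."""
    return sum(pa * math.log(pa / qa) for pa, qa in zip(p, q) if pa > 0)


def kl_regularized_surrogate(traj_logps, rewards, step_pols, step_refs, beta=0.05):
    """Hypothetical GRPO-style objective to maximize (no ratio clipping).

    traj_logps[i]: sum over steps of log g_phi(chosen position) for rollout i.
    step_pols[j], step_refs[j]: g_phi and g_ref position distributions at step j.
    """
    adv = group_advantages(rewards)
    policy_term = mean(a * lp for a, lp in zip(adv, traj_logps))
    kl_term = mean(categorical_kl(p, q) for p, q in zip(step_pols, step_refs))
    return policy_term - beta * kl_term


# Four rollouts (unmasking orders) for one prompt; reward 1.0 = correct answer.
rewards = [1.0, 0.0, 0.0, 1.0]
traj_logps = [-1.0, -2.0, -3.0, -1.5]
same = [[0.7, 0.2, 0.1], [0.5, 0.5]]  # g_phi == g_ref here, so the KL term is 0
obj = kl_regularized_surrogate(traj_logps, rewards, same, same)
```

A group in which every rollout receives the same reward yields all-zero advantages, so only groups with mixed outcomes move the policy; the `beta` term pulls $g_\phi$ back toward $g_{\rm ref}$.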

Abstract

Masked diffusion models (MDMs) have recently emerged as a novel framework for language modeling. MDMs generate sentences by iteratively denoising masked sequences, filling in [MASK] tokens step by step. Although MDMs support any-order sampling, performance is highly sensitive to the choice of which position to unmask next. Prior work typically relies on rule-based schedules (e.g., max-confidence, max-margin), which provide ad hoc improvements. In contrast, we replace these heuristics with a learned scheduler. Specifically, we cast denoising as a KL-regularized Markov decision process (MDP) with an explicit reference policy and optimize a regularized objective that admits policy improvement and convergence guarantees under standard assumptions. We prove that the optimized policy under this framework generates samples that more closely match the data distribution than heuristic schedules. Empirically, across four benchmarks, our learned policy consistently outperforms max-confidence: for example, on SUDOKU, where unmasking order is critical, it yields a 20.1% gain over random and an 11.2% gain over max-confidence.
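For contrast with the learned scheduler, here is a hedged sketch of the rule-based schedules named above, as they are commonly defined: max-confidence unmasks the masked position whose top-1 token probability is largest, and max-margin unmasks the one with the largest gap between its top-1 and top-2 probabilities. The function and argument names are hypothetical, and a real sampler may unmask several positions per step.

```python
import random


def pick_position(token_probs, is_masked, rule="confidence", rng=random):
    """Choose the next position to unmask with a rule-based schedule.

    token_probs[i]: the base MDM's predicted token distribution at position i.
    is_masked[i]:   True if position i is still [MASK].
    """
    candidates = [i for i, masked in enumerate(is_masked) if masked]
    if rule == "random":
        return rng.choice(candidates)

    def score(i):
        top = sorted(token_probs[i], reverse=True)
        if rule == "confidence":  # max-confidence: largest top-1 probability
            return top[0]
        if rule == "margin":  # max-margin: largest top-1 minus top-2 gap
            return top[0] - top[1]
        raise ValueError(f"unknown rule: {rule}")

    return max(candidates, key=score)


token_probs = [
    [0.6, 0.4, 0.0],    # most confident, but close to its runner-up
    [0.5, 0.1, 0.4],
    [0.5, 0.25, 0.25],  # less confident, but the clearest margin
    [0.9, 0.05, 0.05],  # already unmasked, never chosen
]
is_masked = [True, True, True, False]
print(pick_position(token_probs, is_masked, "confidence"))  # 0
print(pick_position(token_probs, is_masked, "margin"))      # 2
```

The two rules disagree on this toy input, which illustrates why the choice of heuristic matters and why a learned policy can help.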

Paper Structure

This paper contains 28 sections, 9 theorems, 64 equations, 5 figures, 3 tables, 2 algorithms.

Key Result

Theorem 1

Assume $1>r_{g_{\rm ref}}>0$ and define, for $\beta>0$: Then, for the local optimizer $\phi_n=\arg\max_{\phi}$ of Eq.~\eqref{eq:theoretical_loss} with $\phi_{\rm old}=\phi_{n-1}$, the probability of success satisfies the following fixed-point iteration, i.e., almost surely for all $\mathbf{q}$ and $n\ge1$, with $r_{g_{\phi_0}}(\mathbf{q})=r_{g_{\rm ref}}(\mathbf{q})$. Also, let $r_{g^*}$ be
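The displayed equations of Theorem 1 did not survive extraction. Purely as an illustration, the sketch below iterates a success-probability map of the form we assume the cited GRPO analysis (mroueh2025reinforcement) uses, $h(p)=\big(1+\tfrac{1-p_{\rm ref}}{p_{\rm ref}}\exp\!\big(-\tfrac{1}{\beta\sqrt{p(1-p)+\varepsilon}}\big)\big)^{-1}$, where $p$ stands for $r_{g_{\phi_n}}(\mathbf{q})$ and $p_{\rm ref}$ for $r_{g_{\rm ref}}(\mathbf{q})$. The exact form of $h$, the constant $\varepsilon$, and the function names are assumptions to be checked against the paper.

```python
import math


def success_map(p, p_ref, beta, eps=1e-6):
    """Assumed map h: success probability before -> after one GRPO iteration."""
    # Under the assumed std normalization, omega_plus + omega_minus = 1/sqrt(p(1-p)+eps).
    weight = 1.0 / math.sqrt(p * (1.0 - p) + eps)
    ref_odds = (1.0 - p_ref) / p_ref  # requires 0 < p_ref < 1, as in the theorem
    return 1.0 / (1.0 + ref_odds * math.exp(-weight / beta))


def success_trajectory(p_ref, beta, n_iters=30, eps=1e-6):
    """Iterate p_n = h(p_{n-1}) starting from p_0 = p_ref."""
    p, history = p_ref, [p_ref]
    for _ in range(n_iters):
        p = success_map(p, p_ref, beta, eps)
        history.append(p)
    return history


hist = success_trajectory(p_ref=0.3, beta=1.0)
```

Because the exponential factor is below 1, every iterate after $p_0$ exceeds $p_{\rm ref}$ under this assumed form. With these settings the iterates climb to within floating-point rounding of 1; a larger `beta` shrinks the exponent and keeps the iterates closer to $p_{\rm ref}$.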

Figures (5)

  • Figure 1: Pass@N on GSM8K. The dashed red line is the single-trajectory accuracy of max-confidence $g_{\rm conf}$.
  • Figure 2: The structure of our unmasking policy model. The model consists of a single Transformer layer and a 3-layer MLP. At inference time, the model proceeds as follows: 1) Given an input sentence, we run the base MDM's Transformer and extract a feature. 2) This feature branches into two paths: first, the base MDM continues to produce a prediction of the token distribution; second, the feature is fed to the policy model's Transformer layer. 3) We concatenate the extracted feature with the base MDM's Top-K probabilities. 4) The concatenated feature is then processed by a 3-layer MLP to yield a policy over unmasking positions. Sharing the structure of the frozen MDM enables memory-efficient training (e.g., LLaDA: 8B, policy model: 134M), since only the unmasking policy model is updated. Details and the memory-efficient training algorithm are provided in Appendix~\ref{appendix:algorithms}.
  • Figure 3: Training accuracy on Sudoku and GSM8K under $\mathcal{L}_{\rm UPO}(g_\phi, g_{\rm ref}, D)$ for different choices of $g_{\rm ref}$. Reference realizations are summarized in Table~\ref{tab:policy_realizations}. “Random initialization” denotes training a randomly initialized model with $\mathcal{L}_{\rm UPO}(g_\phi, \varnothing, \varnothing)$.
  • Figure 4: Mean and standard deviation of group reward during training under $\mathcal{L}_{\rm UPO}$. For SUDOKU, we compare the pretrained policy, trained with and without CE, against $g_{\rm conf}$; for GSM8K, we compare the randomly initialized $g_{\phi_{\rm Top\text{-}K}}$, trained with and without KL, against $g_{\rm Top\text{-}K}$.
  • Figure 5: The structure of our unmasking policy model. The model consists of a single Transformer layer and a 3-layer MLP. At inference time, the model proceeds as follows: 1) Given an input sentence, we run the base MDM's Transformer and extract a feature. 2) This feature branches into two paths: first, the base MDM continues to produce a prediction of the token distribution; second, the feature is fed to the policy model's Transformer layer. 3) We concatenate the extracted feature with the base MDM's Top-K probabilities. 4) The concatenated feature is then processed by a 3-layer MLP to yield a policy over unmasking positions.
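The captions of Figures 2 and 5 describe the policy head in words. Below is a minimal NumPy sketch of steps 3 and 4 only, assuming hypothetical dimensions and random weights. The policy's single Transformer layer is replaced by an identity stand-in, and restricting the softmax to still-masked positions is our assumption rather than something the captions state.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def top_k_probs(token_logits, k):
    """Top-K token probabilities of the base MDM at each position, sorted descending."""
    probs = softmax(token_logits, axis=-1)   # (L, V)
    return -np.sort(-probs, axis=-1)[:, :k]  # (L, k)


def mlp3(x, params):
    """3-layer MLP mapping each position's feature to one scalar score."""
    (W1, b1), (W2, b2), (W3, b3) = params
    h = np.maximum(x @ W1 + b1, 0.0)
    h = np.maximum(h @ W2 + b2, 0.0)
    return (h @ W3 + b3)[:, 0]  # (L,)


def unmask_policy(features, token_logits, is_masked, params, k=4):
    """Hypothetical distribution over positions to unmask next."""
    # Step 3: concatenate features with the base MDM's Top-K probabilities.
    # (The policy's own Transformer layer is omitted: identity stand-in.)
    x = np.concatenate([features, top_k_probs(token_logits, k)], axis=-1)
    # Step 4: score every position, then give already-unmasked ones zero mass.
    scores = np.where(is_masked, mlp3(x, params), -np.inf)
    return softmax(scores)


rng = np.random.default_rng(0)
L, V, d, k, hid = 6, 10, 8, 4, 16
features = rng.normal(size=(L, d))      # stand-in for base-MDM hidden states
token_logits = rng.normal(size=(L, V))  # stand-in for base-MDM token logits
is_masked = np.array([True, False, True, True, False, True])
params = [
    (rng.normal(scale=0.1, size=(d + k, hid)), np.zeros(hid)),
    (rng.normal(scale=0.1, size=(hid, hid)), np.zeros(hid)),
    (rng.normal(scale=0.1, size=(hid, 1)), np.zeros(1)),
]
pi = unmask_policy(features, token_logits, is_masked, params, k=k)
next_pos = rng.choice(L, p=pi)  # sample the next position to unmask
```

Only the small head's parameters (`params`) would be trained in such a setup, which mirrors the caption's point that the frozen base MDM supplies the features.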

Theorems & Definitions (17)

  • Example 1
  • Theorem 1: Restatement of the GRPO convergence theorem (mroueh2025reinforcement)
  • Theorem 2: Reference-KL Tightening for MDM Policy Improvement
  • Proof sketch
  • Proposition 1: Output--Token Level Gradient Alignment (informal)
  • Proof sketch
  • Theorem 2: Reference-KL Tightening for MDM Policy Improvement
  • Proof
  • Lemma 1: GRPO Policy Dynamics (mroueh2025reinforcement)
  • Lemma 2: Policy gradient theorem in a sparse-reward, finite-horizon, episodic, and non-stationary MDP
  • ...and 7 more