Diffusion-State Policy Optimization for Masked Diffusion Language Models

Daisuke Oba, Hiroki Furuta, Naoaki Okazaki

TL;DR

This paper tackles the problem of coarse credit assignment in terminal-reward policy optimization for masked diffusion language models (MDLMs). It introduces Diffusion-State Policy Optimization (DiSPO), a plug-in layer that performs state-wise, same-state branching at intermediate denoising steps by resampling fillings from cached logits and updating only newly filled tokens, thereby providing finer-grained credit assignment without extra diffusion rollouts. The authors formalize a fixed-state objective, derive a policy-gradient estimator, and show that DiSPO can be combined with terminal-feedback objectives to form a mixed objective with coherent gradients. Empirically, DiSPO yields consistent accuracy gains on math and planning benchmarks (e.g., Sudoku, Countdown, GSM8K, MATH500) for LLaDA-8B-Instruct under matched rollout compute and updates, with variance-reduction benefits from token-local updates and same-state averaging. The results demonstrate that intermediate-state credit assignment can meaningfully improve MDLM reasoning and planning tasks, suggesting broader applicability to diffusion-based NLP models.

Abstract

Masked diffusion language models generate by iteratively filling masked tokens over multiple denoising steps, so learning only from a terminal reward on the final completion yields coarse credit assignment over intermediate decisions. We propose DiSPO (Diffusion-State Policy Optimization), a plug-in credit-assignment layer that directly optimizes intermediate filling decisions. At selected intermediate masked states, DiSPO branches by resampling fillings for the currently masked positions from rollout-cached logits, scores the resulting completions, and updates only the newly filled tokens, all without additional multi-step diffusion rollouts. We formalize a fixed-state objective for branched completions and derive a policy-gradient estimator that can be combined with terminal-feedback policy optimization using the same rollouts. On LLaDA-8B-Instruct, DiSPO consistently improves over the terminal-feedback diffu-GRPO baseline on math and planning benchmarks under matched rollout compute and optimizer steps. Our code will be available at https://daioba.github.io/dispo .
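The branching step described in the abstract can be sketched as follows. This is a minimal, hedged illustration, not the authors' implementation: it assumes that every currently masked position is filled in one shot by sampling from its rollout-cached logits, that the reward is a black-box function of the completed sequence, and that the Z drafts branched from the same state are baselined by their mean reward. The names branch_state, MASK, and reward_fn are hypothetical. Because only cached logits are reused, no further denoising passes are run, and each returned action dictionary identifies exactly the newly filled tokens that would receive gradient.

```python
import math
import random

MASK = None  # hypothetical marker for a still-masked position


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def branch_state(state, cached_logits, num_drafts, reward_fn, rng):
    """Resample Z fillings of one intermediate masked state from cached logits.

    state: list of token ids, with MASK at masked positions.
    cached_logits: dict position -> logits cached during the original rollout.
    Returns (action, reward, advantage) per draft, where action maps each
    newly filled position to its sampled token (the only tokens to update).
    """
    masked = [i for i, tok in enumerate(state) if tok is MASK]
    drafts = []
    for _ in range(num_drafts):
        completion = list(state)
        action = {}
        for i in masked:
            probs = softmax(cached_logits[i])
            tok = rng.choices(range(len(probs)), weights=probs, k=1)[0]
            completion[i] = tok
            action[i] = tok
        # Score the completed sequence with the same terminal reward.
        drafts.append((action, reward_fn(completion)))
    # Same-state (within-group) mean baseline turns rewards into advantages.
    mean_r = sum(r for _, r in drafts) / num_drafts
    return [(action, r, r - mean_r) for action, r in drafts]


# Toy usage: 4-token vocabulary, reward = fraction of positions matching a target.
rng = random.Random(0)
state = [3, MASK, 1, MASK]
logits = {1: [0.0, 1.0, 2.0, 0.5], 3: [1.0, 1.0, 1.0, 1.0]}
target = [3, 2, 1, 0]
reward = lambda seq: sum(a == b for a, b in zip(seq, target)) / len(seq)
for action, r, adv in branch_state(state, logits, 4, reward, rng):
    print(action, r, round(adv, 3))
```

Because the advantages are centred within each same-state group, they always sum to zero, so drafts that beat their siblings are pushed up and the rest are pushed down.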

Paper Structure (35 sections, 4 theorems, 55 equations, 6 figures, 6 tables, 1 algorithm)

Key Result

Theorem 4.1

Assume (i) $d_t(q)$ is independent of $\theta$, and (ii) $\tilde{\pi}_\theta(\cdot\mid s)$ is differentiable and normalized for each $s$. Using the within-group mean baseline of the group-advantage equation (eq:group-adv), in the unclipped likelihood-ratio setting the expected gradient of the step-wise loss is aligned with th…
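To make the step-wise estimator concrete, the sketch below computes, for the unclipped on-policy case, the gradient with respect to the logits at a fixed state $s$ of the surrogate $\frac{1}{Z}\sum_{z} A_z \sum_{i\in M(s)} \log \pi_\theta(a_{z,i}\mid s)$, where $M(s)$ (our notation) is the set of positions newly filled by draft $z$ and $A_z$ is its reward minus the within-group mean. It is our own illustration under these assumptions, not the paper's equation; in a real model the gradient would continue into $\theta$ through the logits by backpropagation. The function name stepwise_logit_gradient is hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def stepwise_logit_gradient(state_logits, drafts, rewards):
    """Gradient w.r.t. the logits at a fixed state s of the surrogate
        (1/Z) * sum_z A_z * sum_{i in M(s)} log pi(a_{z,i} | s),
    with A_z = R_z - mean(R) (within-group mean baseline).

    state_logits: dict position -> logits at s (held fixed across drafts).
    drafts: list of dicts position -> sampled token (newly filled tokens only).
    rewards: list of scalar rewards, one per draft.
    Returns dict position -> gradient list, for gradient ascent.
    """
    z = len(drafts)
    mean_r = sum(rewards) / z
    probs = {i: softmax(l) for i, l in state_logits.items()}
    grads = {i: [0.0] * len(l) for i, l in state_logits.items()}
    for action, r in zip(drafts, rewards):
        adv = r - mean_r
        for i, tok in action.items():
            for v, p in enumerate(probs[i]):
                # d log softmax(l)[tok] / d l_v = 1[v == tok] - p_v
                grads[i][v] += adv * ((1.0 if v == tok else 0.0) - p) / z
    return grads


# Toy usage: one masked position, three-token vocabulary, two drafts.
logits = {0: [0.0, 0.0, 0.0]}
print(stepwise_logit_gradient(logits, [{0: 2}, {0: 1}], [1.0, 0.0]))
```

Two properties are easy to check: if all drafts receive the same reward the gradient vanishes, and each position's gradient sums to zero over the vocabulary. As a side remark from a standard calculation (not a restatement of Theorem 4.1), with $Z$ independent drafts the mean-baselined estimator has expectation $\frac{Z-1}{Z}$ times the gradient of the expected reward at $s$, i.e. a positive rescaling of the same direction.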

Figures (6)

  • Figure 1: Conceptual overview. Top: Terminal-feedback GRPO treats the denoising trajectory as one decision. Bottom: DiSPO is a plug-in step that branches at intermediate states (resample $Z$ fillings from cached logits), scores them with the same reward, and backpropagates gradients only through the filled tokens.
  • Figure 2: Reward curves. Terminal reward curves (top) and step reward curves (bottom) on LLaDA-8B-Instruct during policy optimization. Across tasks, DiSPO reaches higher terminal rewards earlier and maintains them over training. Step rewards have smaller magnitudes but follow the same trends as terminal rewards, indicating their role as a complementary training signal.
  • Figure 3: Variance reduction of the step-wise gradient estimator on Sudoku. Left: Updating only action tokens (vs. all tokens) reduces variance at $Z{=}2$ (Prop. 4.3). Right: Increasing $Z$ above $Z{=}2$ reduces variance with action-only updates (Prop. 4.4). Error bars show paired 95% bootstrap CIs.
  • Figure 4: Example of an erroneous completion on Sudoku. Red/blue indicate incorrect/correct fillings; superscripts denote the fill order. diffu-GRPO violates constraints earlier than DiSPO.
  • Figure 5: First-violation time on Sudoku. The first-violation time is the earliest denoising step at which the filled cells violate Sudoku constraints. DiSPO shifts violations to later steps than diffu-GRPO, indicating fewer premature commitments.
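The first-violation time in Figure 5 can be computed with a short routine. This sketch reflects one plausible reading, not the paper's evaluation code: it assumes each blank is filled at most once, that the puzzle givens are consistent, and that a violation means a repeated digit in a row, column, or box. It checks only the cell just filled, which suffices because any new conflict must involve that cell. The names has_conflict and first_violation_step are hypothetical.

```python
def has_conflict(grid, row, col, box=3):
    """True if the digit at (row, col) repeats in its row, column, or box."""
    size = len(grid)
    d = grid[row][col]
    for k in range(size):
        if k != col and grid[row][k] == d:
            return True
        if k != row and grid[k][col] == d:
            return True
    br, bc = box * (row // box), box * (col // box)
    for r in range(br, br + box):
        for c in range(bc, bc + box):
            if (r, c) != (row, col) and grid[r][c] == d:
                return True
    return False


def first_violation_step(puzzle, fills, box=3):
    """Earliest denoising step at which the filled cells break a Sudoku rule.

    puzzle: square list of lists with 0 for blanks (givens assumed consistent).
    fills: list of (step, row, col, digit), each blank filled at most once.
    Returns the step index, or None if the trajectory never violates a rule.
    """
    grid = [list(row) for row in puzzle]
    # sorted() is stable, so cells filled at the same step keep their order.
    for step, row, col, digit in sorted(fills, key=lambda f: f[0]):
        grid[row][col] = digit
        # Any new conflict must involve the cell just filled.
        if has_conflict(grid, row, col, box):
            return step
    return None


# Toy 4x4 usage (2x2 boxes): the third fill repeats digit 1 inside the top-left box.
puzzle = [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
fills = [(1, 0, 1, 2), (2, 1, 0, 3), (3, 1, 1, 1)]
print(first_violation_step(puzzle, fills, box=2))
```

Averaging this quantity over test puzzles gives a per-model statistic; a later first violation corresponds to fewer premature commitments in the sense of Figure 5.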

Theorems & Definitions (4)

  • Theorem 4.1: Principled step-level policy gradient
  • Theorem 4.2: Overall policy gradient
  • Proposition 4.3: Variance reduction by partial updates
  • Proposition 4.4: Variance vs. number of drafts
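Proposition 4.4 concerns how the estimator's variance depends on the number of drafts $Z$. The toy Monte Carlo below illustrates the qualitative trend only; it is not the paper's experiment or proof. It assumes a single masked position, a uniform policy over four tokens, a reward of 1 only for token 0, and the within-group mean baseline, and it measures the empirical variance of one logit-gradient coordinate. All names are hypothetical.

```python
import math
import random
import statistics


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def grad_component(probs, num_drafts, reward_fn, rng, coord=0):
    """One sample of the group-baselined estimator, d/d logit[coord]."""
    actions = rng.choices(range(len(probs)), weights=probs, k=num_drafts)
    rewards = [reward_fn(a) for a in actions]
    mean_r = sum(rewards) / num_drafts
    total = 0.0
    for a, r in zip(actions, rewards):
        # (R_z - mean R) * d log pi(a_z) / d logit[coord]
        total += (r - mean_r) * ((1.0 if a == coord else 0.0) - probs[coord])
    return total / num_drafts


def estimator_variance(num_drafts, trials=2000, seed=0):
    """Empirical variance of the estimator over independent draft groups."""
    rng = random.Random(seed)
    probs = softmax([0.0, 0.0, 0.0, 0.0])  # toy uniform policy over 4 tokens
    reward_fn = lambda a: 1.0 if a == 0 else 0.0  # toy reward: token 0 is correct
    samples = [grad_component(probs, num_drafts, reward_fn, rng) for _ in range(trials)]
    return statistics.pvariance(samples)


if __name__ == "__main__":
    for z in (2, 4, 8, 16):
        print(z, estimator_variance(z))
```

In this toy the coordinate reduces to $q(1-q)$, where $q$ is the fraction of drafts that choose token 0, so its variance shrinks roughly like $1/Z$ as more drafts are averaged at the same state.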