
Learn from Your Mistakes: Self-Correcting Masked Diffusion Models

Yair Schiff, Omer Belhasin, Roy Uziel, Guanghan Wang, Marianne Arriola, Gilad Turok, Michael Elad, Volodymyr Kuleshov

TL;DR

ProSeCo (Progressive Self-Correction) addresses error accumulation in masked diffusion models (MDMs) by jointly training a single model to both unmask tokens and correct its own outputs. During generation, the method interleaves corrective refinement steps with standard unmasking steps. This enables iterative refinement of the whole sequence and improves quality under more aggressive parallel decoding. Experiments on math, coding, molecule generation, and unconditional text show up to ~2–3x faster sampling at maintained or improved accuracy. Spending additional inference-time compute on correction raises benchmark scores by up to ~1.3x over standard MDMs. The work also reports improved guided sampling and preserved output diversity in unconditional generation. It offers guidelines for allocating a budget of corrector steps during inference, presenting a practical way to extend the capabilities of discrete diffusion models.
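The interleaved sampling schedule can be sketched as follows. This is a minimal, hypothetical illustration, not the authors' algorithm. The names `sample_with_correction` and `toy_denoiser` are invented for this sketch. Committing the most confident masked predictions first is a common MDM heuristic, assumed here. The rule that a corrector step overwrites every already-decoded token with the model's current prediction is also an assumption. The function counts network function evaluations (NFE), the compute budget that is traded off against quality.

```python
from typing import Callable, List, Optional, Tuple

Token = int
Seq = List[Optional[Token]]  # None marks a masked position
# Hypothetical denoiser interface: one (token, confidence) prediction per
# position, including positions that are already decoded.
Denoiser = Callable[[Seq], List[Tuple[Token, float]]]


def sample_with_correction(
    denoise: Denoiser, length: int, tokens_per_step: int, corrector_steps: int
) -> Tuple[List[Token], int]:
    """Interleave unmasking and corrector steps; return (sequence, NFE)."""
    seq: Seq = [None] * length
    nfe = 0  # number of network function evaluations
    while any(tok is None for tok in seq):
        # Unmasking step: commit the most confident masked predictions.
        preds = denoise(seq)
        nfe += 1
        masked = [i for i, tok in enumerate(seq) if tok is None]
        masked.sort(key=lambda i: preds[i][1], reverse=True)
        for i in masked[:tokens_per_step]:
            seq[i] = preds[i][0]
        # Corrector steps (assumed rule): the model may overwrite any
        # already-decoded token; masked positions stay masked.
        for _ in range(corrector_steps):
            preds = denoise(seq)
            nfe += 1
            seq = [preds[i][0] if tok is not None else None for i, tok in enumerate(seq)]
    return [tok for tok in seq if tok is not None], nfe


TARGET = [3, 1, 4, 1, 5, 9]


def toy_denoiser(seq: Seq) -> List[Tuple[Token, float]]:
    """Hypothetical stand-in for a trained network, for illustration only.

    It predicts decoded positions correctly, but while more than half of the
    sequence is masked it confidently predicts the wrong token (0) at masked
    positions, mimicking errors from aggressive parallel unmasking.
    """
    frac_masked = sum(tok is None for tok in seq) / len(seq)
    preds = []
    for i, tok in enumerate(seq):
        confidence = 1.0 - i / len(seq)  # earlier positions are more confident
        wrong = tok is None and frac_masked > 0.5
        preds.append((0 if wrong else TARGET[i], confidence))
    return preds


print(sample_with_correction(toy_denoiser, 6, tokens_per_step=3, corrector_steps=0))
print(sample_with_correction(toy_denoiser, 6, tokens_per_step=3, corrector_steps=1))
```

Plain unmasking with 3 tokens per step returns `([0, 0, 0, 1, 5, 9], 2)`. The first-step errors are never revisited. Adding one corrector step per unmasking step recovers the target, `([3, 1, 4, 1, 5, 9], 4)`, at twice the network calls. This is a toy version of the quality-versus-compute trade-off that the figures explore.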

Abstract

Masked diffusion models (MDMs) have emerged as a promising alternative to autoregressive models, enabling parallel token generation while achieving competitive performance. Despite these advantages, MDMs face a fundamental limitation: once tokens are unmasked, they remain fixed, leading to error accumulation and ultimately degrading sample quality. We address this by proposing a framework that trains a model to perform both unmasking and correction. By reusing outputs from the MDM denoising network as inputs for corrector training, we train a model to recover from potential mistakes. During generation, we apply additional corrective refinement steps between unmasking steps to revise already-decoded tokens and improve outputs. We name our training and sampling method Progressive Self-Correction (ProSeCo) for its unique ability to iteratively refine an entire sequence, including already generated tokens. We conduct extensive experimental validation across multiple conditional and unconditional tasks. The experiments demonstrate that ProSeCo yields better quality-efficiency trade-offs (up to ~2–3x faster sampling) and enables inference-time compute scaling to further increase sample quality beyond standard MDMs (up to ~1.3x improvement on benchmarks).
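The training idea above, reusing the denoiser's own outputs as corrector inputs, can be sketched as two loss terms computed by one shared network. This is a hypothetical, framework-free sketch. `self_correction_losses` and the `predict` interface are illustrative names. The following details are assumptions, not details confirmed by the text:

- each token is masked independently with probability t;
- masked slots are filled by sampling from the model's own predictions;
- the corrector loss covers every position;
- the time-dependent weighting of the MDM objective is omitted.

```python
import math
import random
from typing import Callable, Dict, List, Optional, Tuple

Seq = List[Optional[int]]  # None marks a masked position
# Hypothetical network interface: one categorical distribution per position.
Predictor = Callable[[Seq], List[Dict[int, float]]]


def nll(dist: Dict[int, float], target: int, eps: float = 1e-9) -> float:
    """Negative log-likelihood of `target` under a categorical distribution."""
    return -math.log(dist.get(target, 0.0) + eps)


def self_correction_losses(
    predict: Predictor, x0: List[int], t: float, rng: random.Random
) -> Tuple[float, float]:
    # 1) Forward masking: mask each token independently with probability t.
    zt: Seq = [None if rng.random() < t else tok for tok in x0]
    preds = predict(zt)
    masked = [i for i, tok in enumerate(zt) if tok is None]

    # 2) Standard MDM unmasking loss, averaged over masked positions
    #    (the usual time-dependent weighting is omitted here).
    unmask_loss = sum(nll(preds[i], x0[i]) for i in masked) / max(len(masked), 1)

    # 3) Reuse the model's outputs: fill each masked slot with a token sampled
    #    from its own prediction. These samples may be wrong, which is the point.
    #    (In an autodiff framework, no gradient would flow through this step.)
    filled: Seq = list(zt)
    for i in masked:
        toks, probs = zip(*preds[i].items())
        filled[i] = rng.choices(toks, weights=probs)[0]

    # 4) Corrector loss: from the self-generated sequence, recover x0 everywhere.
    corr_preds = predict(filled)
    correct_loss = sum(nll(corr_preds[i], x0[i]) for i in range(len(x0))) / len(x0)
    return unmask_loss, correct_loss
```

In practice, the two terms would be summed, possibly with a weight, and minimized with respect to the shared network's parameters. A single model then serves as both unmasker and corrector.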

Paper Structure (60 sections, 5 equations, 11 figures, 5 tables, 4 algorithms)

Figures (11)

  • Figure 1: (Left) Overview of ProSeCo training. The original process trains the model to generate via unmasking. For every timestep in the masking process, we also train the model to undo corruptions that can arise from sampling from its own unmasking predictions, thereby training it to self-correct. (Right) Supervised fine-tuning (SFT) of the 8B-parameter LLaDA model (Nie et al., 2025) with our method significantly outperforms SFT with vanilla masked diffusion modeling.
  • Figure 2: Demonstrating the benefits of self-correction (LLaDA baseline SFT vs. ProSeCo SFT; block-autoregressive decoding with 4 tokens generated at each step). (Left) Errors occur during parallel unmasking. These mistakes accumulate, and by the 3rd block of generated text, the sample has collapsed. (Middle) ProSeCo can self-correct and recover from errors: after the first block, a short correction loop steers generation 'back on track.' (Right) ProSeCo's ability to directly alter previously decoded tokens leads to a high-quality final output. (Generated sequences are trimmed for illustrative purposes; see the Appendix for full details.)
  • Figure 3: Analyzing the quality-efficiency trade-off for ProSeCo. Standard MDMs (Baseline; gray dot) attain their best performance when decoding a single token at every step. ProSeCo models can vary the number of corrector steps to attain comparable performance more efficiently with fewer unmasking steps (Ours: Fast; green star). They can also achieve even better quality for a modest increase in compute budget (Ours: Balanced; orange star), or maximize quality by scaling inference-time compute even further (Ours: Max; blue star).
  • Figure 4: Pareto frontier of parallel decoding and quality. When decoding in parallel (i.e., fewer unmasking steps on the x-axis), quality deteriorates. Applying a modest number of corrector steps allows ProSeCo models to recover from these errors and extend this frontier.
  • Figure 5: ProSeCo better navigates the Pareto frontier between novelty and property maximization. The x-axis shows the number of novel samples (valid and unique molecules not present in the QM9 dataset). The y-axis shows the mean property value of novel samples. Samples come from controlled generation using discrete classifier-free guidance (Schiff et al., 2024), with varying unmasking steps T (line style) and guidance strength γ (marker size). (Left) Maximizing the ring count property. (Right) Maximizing the drug-likeness (QED) property.
  • ...and 6 more figures
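Figure 5 varies the guidance strength γ of discrete classifier-free guidance. Below is a minimal per-position sketch under one assumption: the common form in which the guided distribution is proportional to p(x | y)^γ · p(x)^(1 − γ). Under this form, γ = 1 recovers the conditional model and γ > 1 sharpens it toward the condition. The function name `guided_probs` is hypothetical, and the cited work's exact parameterization may differ.

```python
import math
from typing import List


def guided_probs(cond_logp: List[float], uncond_logp: List[float], gamma: float) -> List[float]:
    """Combine per-token log-probabilities over one position's vocabulary.

    Returns softmax(gamma * log p(x | y) + (1 - gamma) * log p(x)).
    """
    scores = [gamma * c + (1.0 - gamma) * u for c, u in zip(cond_logp, uncond_logp)]
    m = max(scores)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


cond = [math.log(0.6), math.log(0.4)]    # p(x | y): the condition favors token 0
uncond = [math.log(0.4), math.log(0.6)]  # p(x): unconditionally, token 1 is likelier
for gamma in (0.0, 1.0, 2.0):
    print(gamma, [round(p, 3) for p in guided_probs(cond, uncond, gamma)])
```

In this toy case, token 0's probability rises from 0.6 at γ = 1 to about 0.77 at γ = 2. Stronger guidance pushes samples toward the conditioned property at the cost of diversity, the tension that Figure 5 plots as novelty versus property value.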