Fine-Tuning Masked Diffusion for Provable Self-Correction

Jaeyeon Kim, Seunggeun Kim, Taekyun Lee, David Z. Pan, Hyeji Kim, Sham Kakade, Sitan Chen

TL;DR

This work tackles the limited self-correction ability of Masked Diffusion Models by introducing PRISM, a plug-in, inference-time remasking framework with a provable self-correction loss. PRISM adds a lightweight adapter to pretrained MDMs to learn per-token quality and jointly model unmasking posteriors, enabling tokens to be remasked and revised during inference without retraining the entire model. The approach yields consistent improvements across domains—Sudoku, unconditional text with a 170M MDM, and code with LLaDA-8B—while maintaining efficiency and compatibility with existing architectures. The results demonstrate that PRISM learns calibrated per-token quality signals and enables effective self-correction with modest fine-tuning, highlighting a scalable path toward robust discrete sequence generation.

Abstract

A natural desideratum for generative models is self-correction--detecting and revising low-quality tokens at inference. While Masked Diffusion Models (MDMs) have emerged as a promising approach for generative modeling in discrete spaces, their capacity for self-correction remains poorly understood. Prior attempts to incorporate self-correction into MDMs either require overhauling MDM architectures/training or rely on imprecise proxies for token quality, limiting their applicability. Motivated by this, we introduce PRISM--Plug-in Remasking for Inference-time Self-correction of Masked Diffusions--a lightweight, model-agnostic approach that applies to any pretrained MDM. Theoretically, PRISM defines a self-correction loss that provably learns per-token quality scores, without RL or a verifier. These quality scores are computed in the same forward pass with MDM and used to detect low-quality tokens. Empirically, PRISM advances MDM inference across domains and scales: Sudoku; unconditional text (170M); and code with LLaDA (8B).
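
The remasking idea in the abstract can be illustrated with a toy sketch. This is not the authors' implementation: the bottom-k remasking rule, the greedy confidence-ordered unmasking rule, the `MASK` symbol, and the function names `remask_lowest_quality` and `unmask_most_confident` are all assumptions made for illustration. The quality scores and posteriors are hard-coded; in the setting described above they would come from a single forward pass of the fine-tuned model, recomputed after each change to the sequence.

```python
MASK = "<mask>"  # hypothetical mask symbol


def remask_lowest_quality(tokens, quality, num_remask):
    """Return a copy of `tokens` with the lowest-quality unmasked positions remasked.

    quality[i] is a score in [0, 1] for position i; scores at positions that
    are already masked are ignored.
    """
    unmasked = [i for i, tok in enumerate(tokens) if tok != MASK]
    # Assumption: deterministic bottom-k selection by quality score.
    worst = sorted(unmasked, key=lambda i: quality[i])[:num_remask]
    revised = list(tokens)
    for i in worst:
        revised[i] = MASK
    return revised, sorted(worst)


def unmask_most_confident(tokens, posterior, num_unmask):
    """Fill the masked positions whose top posterior probability is highest.

    posterior[i] maps candidate tokens to probabilities for position i.
    """
    masked = [i for i, tok in enumerate(tokens) if tok == MASK]
    # Assumption: greedy, confidence-ordered unmasking.
    chosen = sorted(masked, key=lambda i: -max(posterior[i].values()))[:num_unmask]
    revised = list(tokens)
    for i in chosen:
        revised[i] = max(posterior[i], key=posterior[i].get)
    return revised


# Toy example with hard-coded model outputs (illustrative values only).
tokens = ["the", "cat", "sat", "on", MASK]
quality = [0.95, 0.90, 0.10, 0.85, 0.0]
posterior = {
    2: {"sat": 0.2, "slept": 0.8},
    4: {"mat": 0.6, "rug": 0.4},
}

remasked, remasked_positions = remask_lowest_quality(tokens, quality, num_remask=1)
revised = unmask_most_confident(remasked, posterior, num_unmask=2)
```

In this toy run the token at position 2 has the lowest quality score, so it is remasked and then refilled together with the originally masked position 4. A vanilla MDM sampler, by contrast, would leave position 2 fixed once it had been unmasked.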

Paper Structure

This paper contains 30 sections, 2 theorems, 26 equations, 7 figures, 4 tables, 4 algorithms.

Key Result

Proposition 1

The per-token quality is the unique minimizer of the PRISM loss $\mathcal{L}(\phi)$ (equation eq:loss_prism).
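
The PRISM loss itself is not reproduced in this summary, so the following is only a generic illustration and not the paper's definition. Assume, hypothetically, a binary cross-entropy objective in which $p \in (0,1)$ is the probability that a token is correct given its context and $q \in (0,1)$ is the predicted score. The pointwise objective then has a unique minimizer at $q = p$:

$$
\ell(q) = -p \log q - (1-p)\log(1-q), \qquad
\ell'(q) = -\frac{p}{q} + \frac{1-p}{1-q} = \frac{q-p}{q(1-q)}, \qquad
\ell''(q) = \frac{p}{q^{2}} + \frac{1-p}{(1-q)^{2}} > 0 .
$$

Since $\ell$ is strictly convex on $(0,1)$ and $\ell'(q) = 0$ only at $q = p$, the minimizer is unique. A statement of the kind made in Proposition 1 would follow from an argument of this shape applied to each context, under the stated assumption on the form of the loss.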

Figures (7)

  • Figure 1: PRISM overview. MDMs learn an unmasking posterior to unmask tokens, after which the unmasked tokens remain fixed (Left). PRISM fine-tuning introduces a per-token quality score, which is used to detect incorrect tokens and remask them. At inference, the fine-tuned MDM jointly computes the unmasking posterior and the per-token quality, which drive unmasking and remasking, respectively (Right).
  • Figure 2: PRISM training pipeline. Fine-tuning samples are built in two steps: (a) a masking process yields $({\mathbf{x}},{\mathbf{z}})$, and (b) a pretrained MDM $f_\theta$ unmasks a few indices in ${\mathbf{z}}$ to produce ${\mathbf{y}}$ (Left). We add a lightweight adapter to the pretrained MDM so that the resulting model computes both the unmasking posterior ($f_\theta$) and the per-token quality ($g_\theta$) (Right).
  • Figure 3: Left: PRISM (red line) achieves a higher success rate on Sudoku puzzles than the baselines. Right: PRISM detects incorrect cells by assigning them low per-token quality (red-colored cells).
  • Figure 4: Unconditional text generation performance. Metrics are evaluated at $\text{NFE} \in \{128, 256, 512, 1024, 2048, 4096\}$; black dashed lines denote validation-set references. PRISM (red) outperforms the baselines (ReMDM: green, ReMDM-conf: blue, vanilla MDM: gray), particularly at fewer sampling steps ($N<1024$). Detailed numerical results are reported in Table \ref{tab:remdm-exp-owt}.
  • Figure 5: Ablation study on fine-tuning hyperparameters $k$ and $n_y$ while holding their product constant ($k \times n_y = 32$). Evaluations are conducted at $128$, $256$, $512$, and $1024$ sampling steps.
  • ...and 2 more figures
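
The two-step sample construction in the Figure 2 caption can be sketched as follows. This is a minimal sketch under stated assumptions, not the paper's procedure: independent per-position masking with probability `mask_prob`, the callable `sample_token` (a stand-in for drawing from the pretrained MDM's unmasking posterior), and the binary match-the-clean-token labels are all hypothetical choices made for illustration.

```python
import random

MASK = "<mask>"  # hypothetical mask symbol


def build_finetuning_sample(x, mask_prob, num_unmask, sample_token, rng):
    """Build (z, y, labels) from a clean sequence x in two steps.

    sample_token(z, i) is a hypothetical stand-in for drawing a token for
    position i from the pretrained MDM's unmasking posterior given z.
    """
    # Step (a): mask positions of the clean sequence x to obtain z.
    # Assumption: each position is masked independently with prob. mask_prob.
    z = [MASK if rng.random() < mask_prob else tok for tok in x]
    masked = [i for i, tok in enumerate(z) if tok == MASK]
    # Step (b): unmask a few of the masked indices using the pretrained model.
    chosen = rng.sample(masked, min(num_unmask, len(masked)))
    y = list(z)
    for i in chosen:
        y[i] = sample_token(z, i)
    # Assumption: the label records whether each filled token matches x.
    labels = {i: int(y[i] == x[i]) for i in chosen}
    return z, y, labels


def always_one(z, i):
    """Toy 'pretrained model' that always proposes the digit 1."""
    return "1"


rng = random.Random(0)
x = list("534678912")  # a toy clean sequence, e.g. one Sudoku row
z, y, labels = build_finetuning_sample(x, 0.5, 2, always_one, rng)
```

Here `labels` maps each newly unmasked index to 1 if the proposed token agrees with the clean sequence and 0 otherwise, which is one plausible supervision signal for a quality head such as $g_\theta$.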

Theorems & Definitions (4)

  • Proposition 1: PRISM
  • proof
  • Proposition 2: Provable guarantee of PRISM-extension
  • proof