Fine-Tuning Masked Diffusion for Provable Self-Correction
Jaeyeon Kim, Seunggeun Kim, Taekyun Lee, David Z. Pan, Hyeji Kim, Sham Kakade, Sitan Chen
TL;DR
This work tackles the limited self-correction ability of Masked Diffusion Models by introducing PRISM, a plug-in, inference-time remasking framework with a provable self-correction loss. PRISM adds a lightweight adapter to pretrained MDMs to learn per-token quality and jointly model unmasking posteriors, enabling tokens to be remasked and revised during inference without retraining the entire model. The approach yields consistent improvements across domains—Sudoku, unconditional text with a 170M MDM, and code with LLaDA-8B—while maintaining efficiency and compatibility with existing architectures. The results demonstrate that PRISM learns calibrated per-token quality signals and enables effective self-correction with modest fine-tuning, highlighting a scalable path toward robust discrete sequence generation.
Abstract
A natural desideratum for generative models is self-correction: detecting and revising low-quality tokens at inference. While Masked Diffusion Models (MDMs) have emerged as a promising approach for generative modeling in discrete spaces, their capacity for self-correction remains poorly understood. Prior attempts to incorporate self-correction into MDMs either require overhauling MDM architectures and training or rely on imprecise proxies for token quality, limiting their applicability. Motivated by this, we introduce PRISM (Plug-in Remasking for Inference-time Self-correction of Masked Diffusions), a lightweight, model-agnostic approach that applies to any pretrained MDM. Theoretically, PRISM defines a self-correction loss that provably learns per-token quality scores, without RL or a verifier. These quality scores are computed in the same forward pass as the MDM's predictions and are used to detect low-quality tokens. Empirically, PRISM advances MDM inference across domains and scales: Sudoku; unconditional text (170M); and code with LLaDA (8B).
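To make the idea of remasking by per-token quality concrete, here is a minimal sketch of one possible selection rule. The abstract does not specify how scores are turned into remasking decisions, so the top-k rule, the `MASK` id, and the function name `remask_lowest_quality` are all assumptions made for illustration, and the quality scores are hard-coded rather than produced by a model.

```python
MASK = -1  # hypothetical mask-token id, chosen only for this sketch


def remask_lowest_quality(tokens, quality, k):
    """Return a copy of `tokens` with the k lowest-scoring unmasked positions set to MASK.

    `quality[i]` stands in for a learned per-token quality score; the top-k
    selection rule is an assumption, not a rule taken from the abstract.
    """
    # Only positions that currently hold a real token are candidates for remasking.
    candidates = [i for i, t in enumerate(tokens) if t != MASK]
    # Order candidates by ascending quality score, breaking ties by position.
    worst = sorted(candidates, key=lambda i: (quality[i], i))[:k]
    out = list(tokens)
    for i in worst:
        out[i] = MASK
    return out


tokens = [5, 3, MASK, 7, 1]
quality = [0.9, 0.2, 0.0, 0.8, 0.4]  # illustrative scores, not model outputs
print(remask_lowest_quality(tokens, quality, k=2))  # [5, -1, -1, 7, -1]
```

In a full sampler, the remasked positions would then be passed back to the MDM to be unmasked again; a threshold on the score would be an equally plausible alternative to the top-k rule used here.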
