Beyond Masked and Unmasked: Discrete Diffusion Models via Partial Masking
Chen-Hao Chao, Wei-Fang Sun, Hanwen Liang, Chun-Yi Lee, Rahul G. Krishnan
TL;DR
This work addresses an inefficiency in masked diffusion models (MDM), a class of discrete diffusion models: many steps of the reverse process are idle, meaning no token changes, so the model repeatedly processes identical inputs. It introduces MDM-Prime, which represents each token as a sequence of sub-tokens through an invertible mapping $f$ to a base-$b$ encoding. Because sub-tokens can be masked individually, a token can occupy intermediate, partially masked states. This yields a finer-grained denoising process with fewer idle steps. The method derives a variational objective for the extended latent space and presents a lightweight encoder/decoder design. The decoder models the joint distribution over a token's sub-tokens and uses a carry-over mechanism so that generated samples remain valid. Empirically, MDLM-Prime (Prime applied to MDLM) reaches a perplexity of $15.36$ on OpenWebText, outperforming prior MDM, autoregressive, and hybrid baselines while remaining order-agnostic. It also attains competitive Fréchet Inception Distance (FID) scores of $3.26$ on CIFAR-10 and $6.98$ on ImageNet-32. The results suggest that discrete diffusion with partial masking can rival autoregressive and continuous diffusion approaches in both text and image domains, offering scalable, non-autoregressive generation with strong likelihoods and sample quality.
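The following minimal Python sketch illustrates the sub-token representation. It makes three assumptions. First, $f$ writes a token index as $\ell$ base-$b$ digits with the most significant digit first. Second, each sub-token is masked independently. Third, `None` serves as a hypothetical mask symbol. All function names are illustrative and are not taken from the paper's code. The `valid_completions` helper shows why validity matters. When $b^\ell$ exceeds the vocabulary size, some digit strings correspond to no token, so a model that samples sub-tokens jointly must assign those strings zero probability. Enumerating the vocabulary, as done here, only illustrates the constraint; it is not the paper's decoder.

```python
import random
from typing import List, Optional

# Hypothetical placeholder for a masked sub-token (not the paper's notation).
MASK: Optional[int] = None


def num_subtokens(vocab_size: int, b: int) -> int:
    """Smallest ell with b**ell >= vocab_size (integer loop avoids log rounding)."""
    ell = 1
    while b ** ell < vocab_size:
        ell += 1
    return ell


def encode(token: int, b: int, ell: int) -> List[int]:
    """Invertible map: token index -> ell base-b digits, most significant first (assumed order)."""
    digits = []
    for _ in range(ell):
        digits.append(token % b)
        token //= b
    return digits[::-1]


def decode(digits: List[int], b: int) -> int:
    """Inverse of encode: base-b digits -> token index."""
    value = 0
    for d in digits:
        value = value * b + d
    return value


def partially_mask(digits: List[int], keep_prob: float, rng: random.Random) -> List[Optional[int]]:
    """Mask each sub-token independently, so a token can be partially observed."""
    return [d if rng.random() < keep_prob else MASK for d in digits]


def valid_completions(partial: List[Optional[int]], b: int, vocab_size: int) -> List[int]:
    """Token indices consistent with the observed (unmasked) sub-tokens.

    Digit strings that decode to values >= vocab_size are unused codes,
    so they never appear in this list.
    """
    ell = len(partial)
    return [
        tok for tok in range(vocab_size)
        if all(p is MASK or p == d for p, d in zip(partial, encode(tok, b, ell)))
    ]


if __name__ == "__main__":
    rng = random.Random(0)
    V, b = 50257, 2  # GPT-2-sized vocabulary and binary sub-tokens, chosen only for illustration
    ell = num_subtokens(V, b)  # 16, since 2**15 < 50257 <= 2**16
    digits = encode(12345, b, ell)
    assert decode(digits, b) == 12345  # round trip through the invertible map
    print(partially_mask(digits, keep_prob=0.5, rng=rng))  # an intermediate, partially masked state
    print(valid_completions([1, MASK, MASK, MASK], b=2, vocab_size=10))  # [8, 9]
```

With $b = 2$ and a 10-token vocabulary, $\ell = 4$. Observing only the leading sub-token 1 narrows the candidates to tokens 8 and 9. The codes for 10 through 15 are unused and must receive no probability mass.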
Abstract
Masked diffusion models (MDM) are powerful generative models for discrete data that generate samples by progressively unmasking tokens in a sequence. Each token can take one of two states: masked or unmasked. We observe that token sequences often remain unchanged between consecutive sampling steps; consequently, the model repeatedly processes identical inputs, leading to redundant computation. To address this inefficiency, we propose the Partial masking scheme (Prime), which augments MDM by allowing tokens to take intermediate states interpolated between the masked and unmasked states. This design enables the model to make predictions based on partially observed token information, and facilitates a fine-grained denoising process. We derive a variational training objective and introduce a simple architectural design to accommodate intermediate-state inputs. Our method demonstrates superior performance across a diverse set of generative modeling tasks. On text data, it achieves a perplexity of 15.36 on OpenWebText, outperforming previous MDM (21.52), autoregressive models (17.54), and their hybrid variants (17.58), without relying on an autoregressive formulation. On image data, it attains competitive FID scores of 3.26 on CIFAR-10 and 6.98 on ImageNet-32, comparable to leading continuous generative models.
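A back-of-the-envelope calculation can estimate the redundant-computation effect described in the abstract. It relies on a simplifying assumption, not a claim about the paper's exact schedule: each of $N$ independently masked units is revealed at a step drawn uniformly from $T$ sampling steps. Under that assumption, a given step is idle with probability $(1 - 1/T)^N$. The sketch compares $N = L$ whole tokens with $N = L\ell$ sub-tokens and checks the formula by simulation. The sequence length $L$, step count $T$, and $\ell$ are illustrative values, and the function names are hypothetical.

```python
import math
import random


def idle_fraction(num_units: int, num_steps: int) -> float:
    """Expected fraction of sampling steps in which no unit is unmasked.

    Assumes each unit's unmasking step is independent and uniform over
    num_steps steps (e.g. a linear schedule split into equal steps).
    """
    return (1.0 - 1.0 / num_steps) ** num_units


def simulate_idle_fraction(num_units: int, num_steps: int, trials: int, seed: int = 0) -> float:
    """Monte Carlo estimate of idle_fraction under the same assumption."""
    rng = random.Random(seed)
    idle = 0
    for _ in range(trials):
        # Steps in which at least one unit gets unmasked.
        active_steps = {rng.randrange(num_steps) for _ in range(num_units)}
        idle += num_steps - len(active_steps)
    return idle / (trials * num_steps)


if __name__ == "__main__":
    L, T, ell = 1024, 1024, 16  # sequence length, sampling steps, sub-tokens per token
    print("tokens only:", idle_fraction(L, T), "vs exp(-1) =", math.exp(-1))
    print("sub-tokens: ", idle_fraction(L * ell, T), "vs exp(-16) =", math.exp(-16))
    print("simulated (tokens only):", simulate_idle_fraction(L, T, trials=20))
```

With $L = T = 1024$, roughly $e^{-1} \approx 37\%$ of steps are idle at the token level. Splitting each token into 16 independently masked sub-tokens pushes the expected idle fraction toward $e^{-16}$, so nearly every step changes the model input.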
