Beyond Masked and Unmasked: Discrete Diffusion Models via Partial Masking

Chen-Hao Chao, Wei-Fang Sun, Hanwen Liang, Chun-Yi Lee, Rahul G. Krishnan

TL;DR

This work addresses the inefficiency of masked diffusion models (MDM) caused by idle steps in the reverse process, i.e., sampling steps in which the token sequence does not change. It introduces MDM-Prime, which augments tokens with intermediate sub-token states via an invertible mapping $f$ to a base-$b$ encoding. This enables a finer-grained denoising process and reduces idle steps. The method derives a variational objective for the extended latent space and presents a lightweight decoder/encoder design that leverages joint sub-token distributions, with a carry-over mechanism to maintain the validity of samples. Empirically, MDLM-Prime (Prime applied to MDLM) achieves a perplexity of $15.36$ on OpenWebText, outperforming several autoregressive and prior MDM variants while remaining order-agnostic. It also attains competitive Fréchet Inception Distance (FID) scores of $3.26$ on CIFAR-10 and $6.98$ on ImageNet-32. The results suggest that discrete diffusion with partial masking can rival autoregressive and continuous diffusion approaches in both text and image domains, offering scalable, non-autoregressive generation with strong likelihoods and sample quality.
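The invertible mapping $f$ can be illustrated with a small sketch, which rests on three assumptions:

- $b$ is the smallest integer with $b^\ell \ge C$. This gives $b = 2$ for both the $C=4, \ell=2$ and the $C=7, \ell=3$ examples in the figure captions.
- Sub-tokens are written most significant digit first.
- The helper names `subtoken_base`, `encode`, and `decode` are hypothetical and not taken from the paper.

```python
def subtoken_base(C: int, ell: int) -> int:
    """Smallest base b >= 2 such that b**ell codes cover C token classes (assumed rule)."""
    b = 2
    while b ** ell < C:
        b += 1
    return b


def encode(x: int, b: int, ell: int) -> list[int]:
    """Map token index x in {0, ..., C-1} to ell base-b sub-tokens, most significant first."""
    digits = []
    for _ in range(ell):
        digits.append(x % b)
        x //= b
    return digits[::-1]


def decode(digits: list[int], b: int) -> int:
    """Inverse of encode: read the sub-tokens back as a base-b number."""
    x = 0
    for d in digits:
        x = x * b + d
    return x


if __name__ == "__main__":
    C, ell = 7, 3
    b = subtoken_base(C, ell)  # 2
    codes = {x: encode(x, b, ell) for x in range(C)}
    # f is invertible on the valid token classes
    assert all(decode(code, b) == x for x, code in codes.items())
    # b**ell - C codes (here 8 - 7 = 1, namely [1, 1, 1]) correspond to no token
    print(b, codes[5], b ** ell - C)
```

Because $b^\ell$ can exceed $C$, some sub-token combinations decode to no valid token. Sampled sub-tokens must therefore be kept consistent with a real class, which is the role the carry-over mechanism plays.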

Abstract

Masked diffusion models (MDM) are powerful generative models for discrete data that generate samples by progressively unmasking tokens in a sequence. Each token can take one of two states: masked or unmasked. We observe that token sequences often remain unchanged between consecutive sampling steps; consequently, the model repeatedly processes identical inputs, leading to redundant computation. To address this inefficiency, we propose the Partial masking scheme (Prime), which augments MDM by allowing tokens to take intermediate states interpolated between the masked and unmasked states. This design enables the model to make predictions based on partially observed token information, and facilitates a fine-grained denoising process. We derive a variational training objective and introduce a simple architectural design to accommodate intermediate-state inputs. Our method demonstrates superior performance across a diverse set of generative modeling tasks. On text data, it achieves a perplexity of 15.36 on OpenWebText, outperforming previous MDM (21.52), autoregressive models (17.54), and their hybrid variants (17.58), without relying on an autoregressive formulation. On image data, it attains competitive FID scores of 3.26 on CIFAR-10 and 6.98 on ImageNet-32, comparable to leading continuous generative models.
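Partial masking can be sketched as a forward corruption step applied to sub-tokens rather than whole tokens. The sketch makes the following assumptions:

- As in standard MDM, each sub-token is masked independently with probability $1-\alpha_t$.
- The sentinel `MASK` and all function names are hypothetical.

A token is then unmasked if none of its sub-tokens are masked, masked if all of them are, and intermediate otherwise.

```python
import random

MASK = -1  # hypothetical sentinel standing in for the mask symbol "m"


def to_subtokens(x: int, b: int, ell: int) -> list[int]:
    """Base-b digits of token index x, most significant first (assumed ordering)."""
    return [(x // b ** (ell - 1 - i)) % b for i in range(ell)]


def partially_mask(subtokens: list[int], alpha_t: float, rng: random.Random) -> list[int]:
    """Keep each sub-token with probability alpha_t, otherwise replace it with MASK."""
    return [s if rng.random() < alpha_t else MASK for s in subtokens]


def token_state(y: list[int]) -> str:
    """Classify a token from its (possibly partially masked) sub-token sequence."""
    n_masked = sum(s == MASK for s in y)
    if n_masked == 0:
        return "unmasked"
    if n_masked == len(y):
        return "masked"
    return "intermediate"


if __name__ == "__main__":
    rng = random.Random(0)
    b, ell, alpha_t = 2, 2, 0.5  # C = b**ell = 4 token classes
    counts = {"unmasked": 0, "intermediate": 0, "masked": 0}
    for _ in range(10_000):
        y = partially_mask(to_subtokens(rng.randrange(b ** ell), b, ell), alpha_t, rng)
        counts[token_state(y)] += 1
    # With ell = 2 and alpha_t = 0.5, roughly half of the tokens are intermediate
    print(counts)
```

With plain MDM ($\ell = 1$) the intermediate category is empty, so the model input changes only when a whole token is unmasked. With $\ell > 1$, a sampling step can also reveal individual sub-tokens, which gives the finer-grained denoising process described above.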

Paper Structure

This paper contains 53 sections, 4 theorems, 32 equations, 28 figures, and 10 tables.

Key Result

Proposition A.1

Let $L$ be the token sequence length, $T$ be the total number of discretized timesteps for the sampling process, and $\alpha_t\in [0,1]$ be a strictly decreasing scheduling function of $t\in [0,1]$. Suppose the sampling timesteps are indexed by $k\in \{0,\cdots,T-1\}$; then the expected number of idle st…
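The proposition's statement is truncated above, so its formula is not reproduced here. As a hedged illustration, the sketch below counts idle steps (steps in which no token is revealed) for a standard MDM. It relies on these assumptions:

- Each token's reveal time is independent of the others and distributed according to the schedule, i.e. a token is still unmasked at time $t$ with probability $\alpha_t$.
- The schedule is linear, $\alpha_t = 1 - t$.
- The function names are hypothetical.

Under these assumptions, a step covering $(t_j, t_{j+1}]$ is idle with probability $(1 - \alpha_{t_j} + \alpha_{t_{j+1}})^L$. The code compares this closed form with a Monte Carlo estimate.

```python
import numpy as np


def alpha_linear(t):
    """Assumed schedule alpha_t = 1 - t: strictly decreasing, alpha_0 = 1, alpha_1 = 0."""
    return 1.0 - t


def expected_idle_closed_form(L, T, alpha=alpha_linear):
    """Expected idle steps when each of L tokens is revealed independently (assumption)."""
    grid = np.linspace(0.0, 1.0, T + 1)
    # Probability that a given token is revealed in the step covering (t_j, t_{j+1}]
    p_reveal = alpha(grid[:-1]) - alpha(grid[1:])
    # A step is idle iff none of the L tokens is revealed during it
    return float(np.sum((1.0 - p_reveal) ** L))


def simulate_idle_steps(L, T, n_runs=2000, seed=0):
    """Monte Carlo estimate, valid for the linear schedule only."""
    rng = np.random.default_rng(seed)
    # P(reveal time > t) = alpha_t = 1 - t, so each reveal time is Uniform(0, 1)
    tau = rng.uniform(0.0, 1.0, size=(n_runs, L))
    # Index j of the sampling step whose interval contains each reveal time
    steps = np.minimum((tau * T).astype(int), T - 1)
    # Idle steps = steps in which no token is revealed
    idle = [T - np.unique(row).size for row in steps]
    return float(np.mean(idle))


if __name__ == "__main__":
    L, T = 16, 32
    print(expected_idle_closed_form(L, T), simulate_idle_steps(L, T))
```

Under these assumptions, idle steps become more frequent as $T$ grows relative to $L$. The measured counts for MDM and MDM-Prime are the ones reported in Figure 1.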

Figures (28)

  • Figure 1: Number of idle steps during the reverse diffusion processes of MDM and MDM-Prime. The results are averaged over ten runs. $\ell$ is the sub-token sequence length.
  • Figure 2: An illustrative example of (a) standard MDM and (b) MDM-Prime. Each token and its corresponding sub-token sequence (constructed via base-$b$ encoding) can take one of three states: unmasked, masked, or intermediate. The masked and intermediate states serve as the latent representations produced by the forward diffusion process. This example contains $C = 4$ possible token classes, labeled 'bird', 'cat', 'dog', and 'frog'. $\ell = 2$ indicates that each token is represented using two sub-tokens, and $b = \sqrt[\ell]{C} = 2$ denotes the number of classes per sub-token. The symbol m represents a masked token or a masked sub-token. The bottom-right sections of (a) and (b) illustrate the state transition trees. MDM-Prime supports transitions through intermediate states while retaining the ability to directly reach unmasked states. The bottom-left portions depict the sampling process for a token sequence of length $L = 4$. In (a), an idle step occurs between steps 3 and 4. In contrast, (b) demonstrates a sampling process without idle steps, which leads to improved model utilization.
  • Figure 3: Illustration of the data and parameterized pmf with an invertible $f$. 'Param. pmf' denotes the parameterized pmf captured using an MDM with parameter $\theta$. In this example, $\ell=2$.
  • Figure 4: Distributions modeled by MDM-Prime using (a) independent and (b) joint parameterizations. Models are trained on a two-dimensional synthetic dataset with $\bm{x}_0\in\{0,\cdots,511\}^2$ representing coordinates in the figure ($512\times 512$). Brighter regions indicate higher probabilities. Experimental details are provided in the appendix.
  • Figure 5: An illustration of the proposed carry-over parameterization technique. In this example, $C = 7$, $\ell = 3$, and $b = 2$. The conditional distribution $p_\theta (\bm{y}_0^i |\bm{y}_t)$ is defined by the parameterization equation in the methodology section. Softmax distributions are formed by normalizing the corresponding logit outputs highlighted in yellow.
  • ...and 23 more figures
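The joint parameterization and carry-over technique (Figures 4 and 5) can be sketched as follows. This is a hedged reading of the captions and the TL;DR, not the paper's exact equation:

- The network's $C$ token logits (a hypothetical array `logits`) are restricted to the classes whose base-$b$ code agrees with every unmasked sub-token in $\bm{y}_t$.
- A softmax is then taken over the surviving classes.
- Because only the $C$ valid classes carry logits, codes at or above $C$ never receive probability mass.

For contrast, an independent factorization over sub-tokens can assign mass to such invalid codes. `MASK` and all function names are hypothetical.

```python
from itertools import product

import numpy as np

MASK = -1  # hypothetical sentinel for a masked sub-token


def to_subtokens(x, b, ell):
    """Base-b digits of token index x, most significant first (assumed ordering)."""
    return [(x // b ** (ell - 1 - i)) % b for i in range(ell)]


def joint_class_probs(logits, y_t, b, ell):
    """Joint pmf over the C token classes given partially observed sub-tokens y_t."""
    logits = np.asarray(logits, dtype=float)
    # Carry-over: a class survives only if its code matches every unmasked sub-token
    keep = np.array([
        all(obs == MASK or obs == s for obs, s in zip(y_t, to_subtokens(c, b, ell)))
        for c in range(len(logits))
    ])
    restricted = np.where(keep, logits, -np.inf)
    # Softmax over the surviving logits only (subtract the max for numerical stability)
    z = np.exp(restricted - restricted[keep].max())
    return z / z.sum()


def independent_code_probs(marginals):
    """Product of per-sub-token marginals over all b**ell codes (independent factorization)."""
    probs = {}
    for code in product(*[range(len(m)) for m in marginals]):
        probs[code] = float(np.prod([m[d] for m, d in zip(marginals, code)]))
    return probs


if __name__ == "__main__":
    C, b, ell = 7, 2, 3  # the setting of Figure 5
    p = joint_class_probs(np.zeros(C), [1, MASK, MASK], b, ell)
    print(p.round(3))  # mass only on classes 4, 5, 6, whose leading sub-token is 1
    q = independent_code_probs([[0.5, 0.5]] * ell)
    print(q[(1, 1, 1)])  # the invalid code (decodes to 7 >= C) still gets 0.125
```

The first printout shows that the joint parameterization keeps observed sub-tokens fixed and assigns no mass to the invalid code `[1, 1, 1]`. The second shows how a factorized model can leak probability onto it.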

Theorems & Definitions (8)

  • Proposition A.1
  • proof
  • Proposition A.2
  • proof
  • Proposition A.3
  • proof
  • Proposition A.4
  • proof