Variational Masked Diffusion Models

Yichi Zhang, Alex Schwing, Zhizhen Zhao

TL;DR

Standard masked diffusion struggles to model dependencies among tokens that are predicted concurrently. Variational Masked Diffusion (VMD) addresses this limitation by injecting a global latent variable $z$ that captures joint token distributions and by deriving a variational objective $L_{\mathrm{VMD}}$; a subsequent Block Diffusion extension scales the approach to blocks of tokens. Across synthetic data, Sudoku, and text, VMD improves dependency modeling and generation quality, outperforming standard masked diffusion baselines and approaching autoregressive-level performance with competitive efficiency. The work provides a principled integration of variational inference into masked diffusion and releases code for reproducibility.
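
The contrast can be written out as a rough sketch. The notation here is an assumption for illustration, not the paper's: $x_t$ is the partially masked sequence, $\mathcal{M}$ is the set of masked positions, and $p(z)$ is a prior over the latent variable. A standard masked diffusion model predicts all masked tokens independently given $x_t$, whereas a latent-variable model is independent only after conditioning on $z$, so marginalizing over $z$ can express dependencies among tokens.

```latex
% Standard masked diffusion: concurrent predictions factorize given x_t
p_\theta\bigl(x_0^{\mathcal{M}} \mid x_t\bigr)
  = \prod_{i \in \mathcal{M}} p_\theta\bigl(x_0^{i} \mid x_t\bigr)

% With a global latent z (prior p(z) is an assumption of this sketch):
% tokens are independent only given z, so the marginal need not factorize
p_\theta\bigl(x_0^{\mathcal{M}} \mid x_t\bigr)
  = \int p(z) \prod_{i \in \mathcal{M}} p_\theta\bigl(x_0^{i} \mid x_t, z\bigr)\, dz
```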

Abstract

Masked diffusion models have recently emerged as a flexible framework for discrete generative modeling. However, a key limitation of standard masked diffusion is its inability to effectively capture dependencies among tokens that are predicted concurrently, leading to degraded generation quality when dependencies among tokens are important. To explicitly model dependencies among tokens, we propose Variational Masked Diffusion (VMD), a framework that introduces latent variables into the masked diffusion process. Through controlled experiments on synthetic datasets, we demonstrate that VMD successfully learns dependencies that conventional masked diffusion fails to capture. We further validate the effectiveness of our approach on Sudoku puzzles and text datasets, where learning of dependencies among tokens improves global consistency. Across these domains, VMD enhances both generation quality and dependency awareness, highlighting the value of integrating variational inference into masked diffusion. Our code is available at: https://riccizz.github.io/VMD.
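
The limitation described above can be illustrated with a small, self-contained sketch. It is not the authors' code: the vocabulary size, the toy distribution, and the hand-written latent model are all hypothetical, chosen so that the two tokens are always equal. A predictor that fills both masked positions independently can at best match the product of the marginals, while a model that first draws a latent $z$ and then predicts each token independently given $z$ can match the joint distribution.

```python
V = 3  # hypothetical vocabulary size

# Ground truth: joint[a][b] = P(x1 = a, x2 = b); the two tokens are always equal.
joint = [[1.0 / V if a == b else 0.0 for b in range(V)] for a in range(V)]

# Factorized one-step prediction from a fully masked input: the best it can do
# is sample each token from its own marginal, independently of the other.
p1 = [sum(joint[a][b] for b in range(V)) for a in range(V)]
p2 = [sum(joint[a][b] for a in range(V)) for b in range(V)]
factorized = [[p1[a] * p2[b] for b in range(V)] for a in range(V)]

# Latent-variable alternative (hand-specified here, not learned): draw z from
# a uniform prior, then predict each token independently *given z*.
prior_z = [1.0 / V] * V


def token_given_z(z, token):
    """P(token | z) for this toy model: z deterministically names the token."""
    return 1.0 if token == z else 0.0


mixture = [
    [
        sum(prior_z[z] * token_given_z(z, a) * token_given_z(z, b) for z in range(V))
        for b in range(V)
    ]
    for a in range(V)
]


def total_variation(p, q):
    """Total variation distance between two joint tables of the same shape."""
    return 0.5 * sum(abs(p[a][b] - q[a][b]) for a in range(V) for b in range(V))


tv_factorized = total_variation(joint, factorized)
tv_mixture = total_variation(joint, mixture)
print(f"factorized: {tv_factorized:.4f}, latent mixture: {tv_mixture:.4f}")
```

In this toy setting the factorized prediction is uniform over all nine token pairs and sits at a total variation distance of 2/3 from the ground truth, while the latent mixture reproduces it exactly.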

Paper Structure

This paper contains 23 sections, 11 equations, 7 figures, 5 tables, 4 algorithms.

Figures (7)

  • Figure 1: Conceptual overview of VMD. (a) Training. The encoder and mask predictor are trained on text with random masks applied independently to all tokens at the same ratio $t \sim U[0, 1]$. $x_0^{i}$ and $x_0^{j}$ are conditionally independent given the latent variable $z$. (b) Sampling. VMD uses the latent sample $z$ to predict masked tokens concurrently, recovering all tokens at $t = 0$ from the fully masked sequence at $t = 1$ with a flexible remasking strategy.
  • Figure 2: Results on controlled synthetic data with 2 tokens under the non-uniform setting. (a) Ground truth distribution, (b) baseline masked diffusion with one-step generation, and (c) our VMD with one-step generation. While the baseline degenerates into nearly uniform random guessing and fails to capture the underlying dependency, VMD accurately recovers the true distribution, demonstrating its ability to learn token-to-token correlations beyond conditional independence.
  • Figure 3: Results on controlled synthetic data with 2 tokens under the deterministic setting. (a) Ground truth distribution, (b) MDM: baseline masked diffusion with one-step generation, and (c) our VMD with one-step generation. MDM one-step generation is identical to random guessing of each token independently, failing to reflect the true correlations, while VMD closely recovers the ground-truth structure.
  • Figure 4: Results on controlled synthetic data with 2 tokens under the varying correlation setting. (a) Ground truth distribution, (b) MDM: baseline masked diffusion with one-step generation, and (c) our VMD with one-step generation. The first row is for data generated with $p=0.3$, and the second row is for data generated with $p=0.8$. The baseline fails to model the data distribution in one-step generation in both cases, producing nearly uniform predictions, while VMD successfully recovers the underlying distributions.
  • Figure 5: Accuracy ($\uparrow$) during training on the Sudoku puzzle experiment with NFE=20 under the top probability remasking strategy. VMD consistently outperforms the baseline (MDM) across training iterations. The final accuracy reaches 18.97% for the baseline and 82.03% for VMD.
  • ...and 2 more figures
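
The training-time corruption described in the Figure 1 caption, where every token is masked independently at a shared ratio $t \sim U[0, 1]$, can be sketched as follows. This is a minimal illustration under stated assumptions, not the released implementation: the `MASK` symbol, the `mask_tokens` helper, and the example sentence are all hypothetical.

```python
import random

MASK = "[MASK]"  # hypothetical mask symbol


def mask_tokens(tokens, t, rng):
    """Replace each token with MASK independently with probability t."""
    return [MASK if rng.random() < t else tok for tok in tokens]


rng = random.Random(0)  # seeded only to make the sketch repeatable
tokens = "the cat sat on the mat".split()

t = rng.random()  # one masking ratio per sequence, t ~ U[0, 1)
noisy = mask_tokens(tokens, t, rng)
print(f"t = {t:.3f}: {noisy}")

# The two endpoints of the process: nothing masked at t = 0,
# a fully masked sequence at t = 1.
clean = mask_tokens(tokens, 0.0, rng)
fully_masked = mask_tokens(tokens, 1.0, rng)
```

Sampling runs this process in reverse, starting from the fully masked sequence at $t = 1$ and ending with all tokens recovered at $t = 0$.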