
Any-Order Flexible Length Masked Diffusion

Jaeyeon Kim, Lee Cheuk-Kit, Carles Domingo-Enrich, Yilun Du, Sham Kakade, Timothy Ngotiaoco, Sitan Chen, Michael Albergo

TL;DR

The paper tackles a key limitation of Masked Diffusion Models (MDMs): they cannot handle variable-length sequences or token insertions. It introduces FlexMDM, a discrete diffusion model built on a joint interpolant that models insertions and unmasking within a single continuous-time Markov chain while preserving any-order inference. The authors derive training losses and rate matrices with variational guarantees. They show that FlexMDM models length statistics with higher fidelity and improves notably on planning tasks. They also show that pretrained MDMs can be retrofitted into FlexMDMs at the 8B scale with modest compute, yielding gains on GSM8K math and code infilling.

Abstract

Masked diffusion models (MDMs) have recently emerged as a promising alternative to autoregressive models over discrete domains. MDMs generate sequences in an any-order, parallel fashion, enabling fast inference and strong performance on non-causal tasks. However, a crucial limitation is that they do not support token insertions and are thus limited to fixed-length generations. To this end, we introduce Flexible Masked Diffusion Models (FlexMDMs), a discrete diffusion paradigm that can simultaneously model sequences of flexible length while provably retaining MDMs' flexibility of any-order inference. Grounded in an extension of the stochastic interpolant framework, FlexMDMs generate sequences by inserting mask tokens and unmasking them. Empirically, we show that FlexMDMs match MDMs in perplexity while modeling length statistics with much higher fidelity. On a synthetic maze planning task, they achieve an $\approx 60\%$ higher success rate than MDM baselines. Finally, we show that pretrained MDMs can easily be retrofitted into FlexMDMs: on 16 H100s, it takes only three days to fine-tune LLaDA-8B into a FlexMDM, achieving superior performance on math (GSM8K, $58\% \to 67\%$) and code infilling ($52\% \to 65\%$).
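
The insert-then-unmask generation described above can be illustrated with a small, hypothetical sampler. This is a tau-leaping-style sketch under stated assumptions, not the authors' algorithm. It assumes a linear schedule, so a pending event fires in $[t, t+dt)$ with probability $dt/(1-t)$. It also assumes that insertions per gap are Poisson with mean $g \cdot dt/(1-t)$, where $g$ is the predicted expected number of masks still to be inserted (the role of $g_\theta$). The names `generate`, `toy_insertions`, `toy_posterior`, and `MASK` are illustrative stand-ins, not an API from the paper.

```python
import numpy as np

MASK = -1  # hypothetical id for the mask token


def generate(predict_insertions, predict_posterior, num_steps=50, seed=0):
    """Sketch of insert-then-unmask sampling starting from the empty sequence.

    Assumption: linear schedule, so dt / (1 - t) = 1 / (num_steps - k),
    which equals exactly 1 on the final step (every remaining mask resolves).
    """
    rng = np.random.default_rng(seed)
    seq = []  # t = 0: empty sequence
    for k in range(num_steps):
        t = k / num_steps
        p = 1.0 / (num_steps - k)
        # Insertion: g[i] = expected masks still to insert into gap i
        # (gaps sit before, between, and after the current tokens).
        g = np.asarray(predict_insertions(seq, t), dtype=float)
        counts = rng.poisson(g * p)
        new_seq = []
        for i in range(len(seq) + 1):
            new_seq.extend([MASK] * int(counts[i]))
            if i < len(seq):
                new_seq.append(seq[i])
        seq = new_seq
        # Unmasking: each mask is revealed with probability p, its token drawn
        # from the predicted posterior over clean tokens (the role of f_theta).
        probs = predict_posterior(seq, t)  # shape (len(seq), vocab_size)
        for i, tok in enumerate(seq):
            if tok == MASK and rng.random() < p:
                seq[i] = int(rng.choice(probs.shape[1], p=probs[i]))
    return seq


TARGET_LEN, VOCAB = 6, 4  # toy settings (hypothetical)


def toy_insertions(seq, t):
    # Hypothetical stand-in for g_theta: spread the missing length evenly over gaps.
    missing = max(TARGET_LEN - len(seq), 0)
    return np.full(len(seq) + 1, missing / (len(seq) + 1))


def toy_posterior(seq, t):
    # Hypothetical stand-in for f_theta: always predicts token 2.
    probs = np.zeros((len(seq), VOCAB))
    probs[:, 2] = 1.0
    return probs


out = generate(toy_insertions, toy_posterior)
print(out)
```

With the linear schedule, the per-step probability reaches 1 on the last step. Every mask present at the end is therefore unmasked, so the output never contains `MASK`. Its length is random because insertions are Poisson draws.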

Paper Structure

This paper contains 54 sections, 14 theorems, 54 equations, 6 figures, 1 table, 2 algorithms.

Key Result

Proposition 1

The loss $\mathcal{L}_\theta$ in Eq. (eq:FlexMDM_loss) is uniquely minimized at

Figures (6)

  • Figure 1: Flexible Masked Diffusion Model (FlexMDM) addresses MDMs’ inability to handle variable-length sequences and token insertion while preserving any-order generation power. At each step, FlexMDM performs insertion and unmasking by predicting the expected number of mask tokens to insert ($g_\theta$) and the posterior over clean tokens ($f_\theta$), respectively.
  • Figure 2: To draw a sample $x_t$, one can equivalently sample the clean sequence $x_1\sim p_1$, draw unmasking times, and then accordingly unmask or mask each coordinate's token.
  • Figure 3: Left (FlexMDM interpolant). To draw a sample $x_t$, one can equivalently draw a sample $x_1 \sim p_1$, and for each token unmask, mask, or remove it according to the unmasking and insertion times $(T_1^i, T_2^i)$. An auxiliary interpolant $s_t$ gives closed-form expressions for the FlexMDM rate matrices. Right (FlexMDM Inference). Learned unmasking posterior and insertion expectation are later used at inference.
  • Figure 4: Subroutine 1: VLMDM inference
  • Figure 5: FlexMDM performance exhibits superior scaling when more sampling steps are allocated.
  • ...and 1 more figure
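
The interpolants in Figures 2 and 3 can be sketched as per-token sampling from a clean sequence $x_1$. This is a minimal illustration under assumptions not taken from the paper. $T_1$ is treated as the insertion time and $T_2 \ge T_1$ as the unmasking time, with $T_1 \sim U(0,1)$ and $T_2 \sim U(T_1, 1)$. Time runs from $t=0$ (empty sequence) to $t=1$ (clean data). The function names are hypothetical.

```python
import numpy as np

MASK = -1  # hypothetical id for the mask token


def sample_flex_interpolant(x1, t, rng):
    """FlexMDM-style sketch: each token is removed, masked, or clean at time t.

    Assumed (illustrative) schedules: T1 ~ U(0, 1) is the insertion time,
    T2 ~ U(T1, 1) is the unmasking time.
    """
    out = []
    for tok in x1:
        t1 = rng.uniform(0.0, 1.0)
        t2 = rng.uniform(t1, 1.0)
        if t < t1:
            continue  # not yet inserted: the token is absent from x_t
        out.append(tok if t >= t2 else MASK)  # inserted: clean once unmasked
    return out


def sample_mdm_interpolant(x1, t, rng):
    """MDM special case (Figure 2): every token is present (T1 = 0);
    it is clean if its unmasking time T ~ U(0, 1) has passed, else masked."""
    return [tok if t >= rng.uniform(0.0, 1.0) else MASK for tok in x1]


rng = np.random.default_rng(0)
x1 = [5, 7, 9, 11]
for t in (0.0, 0.5, 1.0):
    print(t, sample_mdm_interpolant(x1, t, rng), sample_flex_interpolant(x1, t, rng))
```

Under these assumptions, the MDM sample always has length `len(x1)`. The FlexMDM sample keeps only tokens whose insertion time has passed, so its length grows with $t$ while the surviving tokens stay in their original order.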

Theorems & Definitions (30)

  • Proposition 1: FlexMDM training loss
  • Proposition 2: FlexMDM Rate Matrix
  • Proposition 3: Any-order inference, informal
  • Definition C.1: Discrete Stochastic Interpolant and Interpolating Rate
  • Proposition C.1: Target Rate
  • Proof
  • Definition C.2: The Masked Diffusion Interpolant
  • Proposition C.2: The Masked Diffusion Interpolating Rate
  • Proof
  • Proposition 4: The Masked Diffusion Target Rate
  • ...and 20 more
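
For intuition about the masked-diffusion rates listed above (Definition C.2, Proposition C.2, Proposition 4), here is a standard derivation under stated assumptions. It is not a transcription of the paper's statements, and the notation is assumed. Suppose each position's unmasking time $T$ has a differentiable CDF $\alpha_t = \Pr(T \le t)$ with $\alpha_0 = 0$ and $\alpha_1 = 1$. Write $x_t^{i \leftarrow v}$ for $x_t$ with position $i$ set to token $v$.

```latex
% Hazard of unmasking for a position that is still masked at time t
\Pr\bigl(T \in (t, t+h] \mid T > t\bigr)
  = \frac{\alpha_{t+h} - \alpha_t}{1 - \alpha_t}
  = \frac{\dot{\alpha}_t}{1 - \alpha_t}\,h + o(h).

% Conditioned on x_1, a masked position i jumps to x_1^i at this rate.
% Averaging over the posterior of x_1 given x_t gives the marginal rate:
R_t\bigl(x_t,\; x_t^{i \leftarrow v}\bigr)
  = \frac{\dot{\alpha}_t}{1 - \alpha_t}\,
    \Pr\bigl(x_1^i = v \mid x_t\bigr),
  \qquad \text{for } x_t^i = \texttt{[MASK]}.
```

For the linear schedule $\alpha_t = t$, the rate is $1/(1-t)$. It diverges as $t \to 1$, so every remaining mask is resolved by $t = 1$.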