Any-Order Flexible Length Masked Diffusion
Jaeyeon Kim, Lee Cheuk-Kit, Carles Domingo-Enrich, Yilun Du, Sham Kakade, Timothy Ngotiaoco, Sitan Chen, Michael Albergo
TL;DR
The paper addresses a key limitation of Masked Diffusion Models (MDMs): they cannot handle variable-length sequences or token insertions. It introduces FlexMDM, a discrete diffusion model built on a joint interpolant that models both insertions and unmasking within a continuous-time Markov chain while preserving any-order inference. The authors derive training losses and rate matrices with variational guarantees, and show that FlexMDM models length statistics with higher fidelity and yields notable improvements on planning tasks. They also show that pretrained MDMs can be retrofitted into FlexMDMs with modest compute, scaling to an 8B-parameter model and achieving gains on GSM8K math and code infilling.
Abstract
Masked diffusion models (MDMs) have recently emerged as a promising alternative to autoregressive models over discrete domains. MDMs generate sequences in an any-order, parallel fashion, enabling fast inference and strong performance on non-causal tasks. However, a crucial limitation is that they do not support token insertions and are thus limited to fixed-length generations. To address this, we introduce Flexible Masked Diffusion Models (FlexMDMs), a discrete diffusion paradigm that can model sequences of flexible length while provably retaining the any-order inference of MDMs. Grounded in an extension of the stochastic interpolant framework, FlexMDMs generate sequences by inserting mask tokens and unmasking them. Empirically, we show that FlexMDMs match MDMs in perplexity while modeling length statistics with much higher fidelity. On a synthetic maze planning task, they achieve an $\approx 60\%$ higher success rate than MDM baselines. Finally, we show that pretrained MDMs can easily be retrofitted into FlexMDMs: on 16 H100s, it takes only three days to fine-tune LLaDA-8B into a FlexMDM, which achieves superior performance on math (GSM8K, $58\% \to 67\%$) and code infilling ($52\% \to 65\%$).
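To make the "insert mask tokens, then unmask them" generation loop concrete, here is a minimal toy sketch, not the paper's sampler. It assumes generation starts from an empty sequence, and the functions `insert_fn` and `unmask_fn` and the constant `unmask_prob` are hypothetical stand-ins for the learned insertion and unmasking rates of the continuous-time Markov chain. Each step first inserts masks into the gaps of the current sequence, which is how the length can grow, and then reveals masked positions in any order. Revealed tokens are never changed afterwards.

```python
import random
from typing import Callable, List

MASK = "[MASK]"

def flex_sample(
    num_steps: int,
    insert_fn: Callable[[List[str], float], List[int]],  # hypothetical: masks to insert per gap
    unmask_fn: Callable[[List[str], int, float], str],   # hypothetical: token for a masked position
    unmask_prob: float,                                  # hypothetical: per-step reveal probability
    rng: random.Random,
) -> List[str]:
    seq: List[str] = []  # assumption: generation starts from an empty sequence
    for step in range(num_steps):
        t = step / num_steps
        # 1) Insertion: counts[g] masks go into gap g (len(seq) + 1 gaps, incl. both ends).
        counts = insert_fn(seq, t)
        new_seq: List[str] = []
        for gap, c in enumerate(counts):
            new_seq.extend([MASK] * c)
            if gap < len(seq):
                new_seq.append(seq[gap])
        seq = new_seq
        # 2) Unmasking: each masked position is revealed independently, in any order.
        for i, tok in enumerate(seq):
            if tok == MASK and rng.random() < unmask_prob:
                seq[i] = unmask_fn(seq, i, t)
    # Reveal any masks that remain after the last step.
    for i, tok in enumerate(seq):
        if tok == MASK:
            seq[i] = unmask_fn(seq, i, 1.0)
    return seq

# Toy usage with hypothetical stand-in "models".
rng = random.Random(0)
TARGET_LEN = 8

def toy_insert(seq: List[str], t: float) -> List[int]:
    # Insert one mask into a random gap until TARGET_LEN tokens exist.
    counts = [0] * (len(seq) + 1)
    if len(seq) < TARGET_LEN:
        counts[rng.randrange(len(counts))] = 1
    return counts

def toy_unmask(seq: List[str], i: int, t: float) -> str:
    # A "denoiser" that simply picks a random letter.
    return rng.choice("abcdefgh")

out = flex_sample(20, toy_insert, toy_unmask, unmask_prob=0.3, rng=rng)
print(out)
```

Note that the length is never fixed in advance. In this toy it is set by `toy_insert`, whereas in a FlexMDM a trained model would decide where and how many masks to insert.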
