Masked Diffusion Models are Secretly Learned-Order Autoregressive Models
Prateek Garg, Bhavya Kohli, Sunita Sarawagi
TL;DR
This work extends Masked Diffusion Models (MDMs) to learn favorable decoding orders over discrete data by using multivariate noise schedules. The authors prove that the continuous-time ELBO of MDMs decomposes into auto-regressive losses over all possible decoding orders, weighted by a distribution over orders that the schedule defines. This lets the decoding order be learned through a state-independent schedule. They establish an exact correspondence between token ordering and inference-time schedules, and validate the theory on tabular data, where learned schedules modestly improve validation loss while keeping data-fidelity metrics competitive. Overall, the work shows that diffusion-based training can implicitly discover and optimize a token ordering, offering a principled path toward learnable, order-aware generative models for discrete domains. Potential implications include better speed-accuracy trade-offs at inference and structure discovery in complex data.
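The schedule-to-order correspondence can be illustrated with a short Monte Carlo sketch. It assumes the following, none of which the summary fixes. First, α_i(t), the probability that token i is still unmasked at time t, follows the hypothetical family α_i(t) = 1 − t^{w_i}, with one positive exponent w_i per token. Second, the forward process masks tokens independently. Under these assumptions, each token's masking time τ_i has CDF t^{w_i}, and the reverse process (running t from 1 down to 0) unmasks tokens in decreasing order of τ_i. Because −log τ_i is exponentially distributed with rate w_i, token i should be decoded first with probability w_i / Σ_j w_j, and the code checks this empirically. The function name and the exponents are illustrative only.

```python
import numpy as np


def sample_decoding_orders(w, num_samples, seed=0):
    """Sample decoding orders under the hypothetical schedule alpha_i(t) = 1 - t**w_i.

    Forward process: token i becomes masked at time tau_i with
    P(tau_i <= t) = 1 - alpha_i(t) = t**w_i, so tau_i = U**(1 / w_i) (inverse CDF).
    Reverse process: t runs from 1 down to 0, so the token masked last
    (largest tau_i) is decoded first.
    """
    rng = np.random.default_rng(seed)
    w = np.asarray(w, dtype=float)
    u = rng.random((num_samples, w.size))
    tau = u ** (1.0 / w)             # per-token masking times, shape (num_samples, n)
    return np.argsort(-tau, axis=1)  # each row lists token indices in decoding order


w = [1.0, 2.0, 5.0]  # hypothetical per-token schedule exponents
orders = sample_decoding_orders(w, num_samples=200_000)
first_freq = np.bincount(orders[:, 0], minlength=len(w)) / len(orders)
print(first_freq)             # empirical P(token i is decoded first)
print(np.array(w) / sum(w))   # race-of-exponentials prediction w_i / sum_j w_j
```

Setting all w_i equal makes every order equally likely. This matches the observation that a standard MDM with a single shared schedule effectively trains over uniformly random orders.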
Abstract
Masked Diffusion Models (MDMs) have emerged as one of the most promising paradigms for generative modeling over discrete domains. It is known that MDMs are effectively trained to decode tokens in a random order, and that this ordering has significant performance implications in practice. This observation raises a fundamental question: can we design a training framework that optimizes for a favorable decoding order? We answer this in the affirmative, showing that the continuous-time variational objective of MDMs, when equipped with multivariate noise schedules, can identify and optimize for a decoding order during training. We establish a direct correspondence between decoding order and the multivariate noise schedule and show that this setting breaks the invariance of the MDM objective to the noise schedule. Furthermore, we prove that the MDM objective decomposes precisely into a weighted sum of auto-regressive losses over these orders, which establishes MDMs as auto-regressive models with learnable orders.
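The "weighted sum of auto-regressive losses over orders" can be made concrete on a toy problem. The sketch below enumerates all n! decoding orders and weights each order's auto-regressive negative log-likelihood by that order's probability. It assumes a hypothetical Plackett-Luce order distribution with per-token weights w. That is the distribution a race of independent per-token masking times produces under a schedule such as α_i(t) = 1 − t^{w_i}, but the paper's exact weights and model parameterization may differ. The helper `log_cond` and its interface are illustrative names, not the authors' code. With exact conditionals, the chain rule gives every order the same NLL, −log p(x). The weighted loss therefore equals −log p(x) for any w, and the weights over orders can only change the objective when the learned conditionals are imperfect.

```python
import itertools
import math


def plackett_luce_prob(order, w):
    """P(order) when tokens are decoded in a Plackett-Luce race with weights w (assumed)."""
    remaining = list(order)
    prob = 1.0
    for tok in order:
        prob *= w[tok] / sum(w[j] for j in remaining)
        remaining.remove(tok)
    return prob


def weighted_ar_loss(x, w, log_cond):
    """Sum over all decoding orders of P(order) * autoregressive NLL of x along that order.

    log_cond(i, value, context) returns log p(x_i = value | context), where
    context maps already-decoded positions to their values (hypothetical interface).
    """
    total = 0.0
    for order in itertools.permutations(range(len(x))):
        context = {}
        nll = 0.0
        for i in order:
            nll -= log_cond(i, x[i], context)
            context[i] = x[i]  # token i is now decoded and visible to later steps
        total += plackett_luce_prob(order, w) * nll
    return total


def make_exact_conditional(joint):
    """Exact conditionals p(x_i | context), computed by marginalizing a joint table."""
    def log_cond(i, value, context):
        consistent = [(y, p) for y, p in joint.items()
                      if all(y[j] == v for j, v in context.items())]
        denom = sum(p for _, p in consistent)
        numer = sum(p for y, p in consistent if y[i] == value)
        return math.log(numer / denom)
    return log_cond


# Toy joint over 3 binary tokens: p(y) proportional to 1, 2, ..., 8 (which sum to 36).
joint = {y: (k + 1) / 36 for k, y in enumerate(itertools.product([0, 1], repeat=3))}
exact = make_exact_conditional(joint)
x = (1, 0, 1)  # p(x) = 6/36

for w in ([1.0, 1.0, 1.0], [1.0, 2.0, 5.0]):
    print(w, weighted_ar_loss(x, w, exact), -math.log(joint[x]))
```

Enumerating n! orders is only feasible for tiny n. A training objective would instead sample orders, or equivalently masking times, and estimate the same weighted sum.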
