Optimal Inference Schedules for Masked Diffusion Models
Sitan Chen, Kevin Cong, Jerry Li
TL;DR
This work analyzes parallel token sampling in Masked Diffusion Models (MDMs) by characterizing the optimal unmasking schedule in terms of the distribution’s information curve ${\bf Z}(\mu)$. It establishes that the minimal expected KL divergence between the target distribution and the sampler’s output equals the left-Riemann error between ${\bf Z}$ and its approximation ${\bf Z}^{\mathbf N}$, and it shows how to compute the optimal schedule via dynamic programming. The authors prove strong impossibility results for universal (distribution-agnostic) schedulers, and then provide practical information-theoretic upper bounds based on total correlation ${\rm TC}$ and dual total correlation ${\rm DTC}$, yielding schedules with $k = O((1+\log n)/\varepsilon)$ steps under coarse estimates. These results clarify when sublinear, near-optimal parallel sampling is possible and offer guidance for tuning schedule hyperparameters in real deployments. Overall, the paper connects diffusion-based sampling, information-theoretic measures, and univariate function approximation to deliver rigorous limits and actionable strategies for efficient non-autoregressive language modeling.
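To make the dynamic-programming claim concrete, here is a minimal sketch under an assumed discretization (not the paper's exact definition of ${\bf Z}$): the information curve is given as a nonincreasing array `Z` of length $n$, where `Z[i]` stands for the curve's value after $i$ tokens have been revealed, and a step that unmasks tokens $s, \dots, t-1$ at once is charged `Z[s]` for each of them instead of `Z[i]`. Under that assumption a schedule's error is a left-Riemann gap, and an $O(k n^2)$ dynamic program over breakpoints finds the schedule with at most $k$ steps that minimizes it. The names `schedule_cost` and `optimal_schedule` are hypothetical illustrations, not the authors' code.

```python
import math
from typing import List, Tuple


def schedule_cost(Z: List[float], schedule: List[int]) -> float:
    """Left-Riemann error of a schedule 0 = s_0 < s_1 < ... < s_k = n.

    ASSUMPTION (illustrative, not the paper's exact definition): Z[i] is the
    information curve after i revealed tokens, and a step starting at s that
    unmasks tokens s..t-1 is charged Z[s] per token instead of Z[i].
    """
    total = 0.0
    for a, b in zip(schedule, schedule[1:]):
        total += sum(Z[a] - Z[i] for i in range(a, b))
    return total


def optimal_schedule(Z: List[float], k: int) -> Tuple[float, List[int]]:
    """Minimize the left-Riemann error over schedules with at most k steps (n, k >= 1)."""
    n = len(Z)
    k = min(k, n)  # each step unmasks at least one token
    prefix = [0.0]
    for z in Z:
        prefix.append(prefix[-1] + z)

    def cost(a: int, b: int) -> float:
        # (b - a) * Z[a] - sum_{i=a}^{b-1} Z[i], computed in O(1) via prefix sums
        return (b - a) * Z[a] - (prefix[b] - prefix[a])

    # best[j][t]: minimal error for unmasking the first t tokens in exactly j steps
    best = [[math.inf] * (n + 1) for _ in range(k + 1)]
    parent = [[-1] * (n + 1) for _ in range(k + 1)]
    best[0][0] = 0.0
    for j in range(1, k + 1):
        for t in range(j, n + 1):
            for s in range(j - 1, t):
                cand = best[j - 1][s] + cost(s, t)
                if cand < best[j][t]:  # strict: ties keep the earliest breakpoint
                    best[j][t] = cand
                    parent[j][t] = s

    # Pick the best step count j <= k, then walk parent pointers back to 0.
    j = min(range(1, k + 1), key=lambda jj: best[jj][n])
    err = best[j][n]
    schedule = [n]
    t = n
    while j > 0:
        t = parent[j][t]
        j -= 1
        schedule.append(t)
    return err, schedule[::-1]


Z = [1.0, 0.5, 0.25, 0.125]  # toy nonincreasing curve, n = 4
print(optimal_schedule(Z, 2))  # (0.625, [0, 1, 4])
```

For this toy curve, splitting after the first or the second token gives the same error of 0.625; the strict comparison keeps the earlier breakpoint.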
Abstract
A major bottleneck of standard autoregressive large language models is that their inference process is inherently sequential, resulting in long and costly inference. To circumvent this, practitioners have proposed a class of language models called diffusion language models, of which the masked diffusion model (MDM) is the most successful. The MDM can sample tokens out of order and, ostensibly, many tokens at once in parallel. However, there is very limited rigorous understanding of how much parallel sampling these models can perform without noticeable degradation in sampling performance. Prior work by Li and Cai obtained some preliminary bounds, but these are not tight for many natural classes of distributions. In this work, we give a new, exact characterization of the expected divergence between the true distribution and the sampled distribution, for any distribution and any unmasking schedule for the sampler, showing an elegant connection to the theory of univariate function approximation. By leveraging this connection, we then obtain a number of novel lower and upper bounds for this problem. While the connection to function approximation in principle gives the optimal unmasking schedule for any distribution, we show that it is in general impossible to compete with this optimal schedule without strong a priori knowledge of the distribution, even in seemingly benign settings. However, we also demonstrate new upper bounds and new sampling schedules in terms of well-studied information-theoretic properties of the base distribution, namely its total correlation and dual total correlation, which show that in some natural settings one can sample in $O(\log n)$ steps without any visible loss in performance, where $n$ is the total sequence length.
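The upper bounds above are stated in terms of total correlation, ${\rm TC}(X) = \sum_i H(X_i) - H(X)$, and dual total correlation, ${\rm DTC}(X) = H(X) - \sum_i H(X_i \mid X_{-i})$. The sketch below computes both from these standard definitions for a small joint pmf given explicitly as a dictionary. The helper names are hypothetical, the computation is brute force (exponential in $n$, so only for toy examples), and it does not implement the paper's sampling schedules.

```python
import math
from collections import defaultdict
from typing import Dict, Iterable, Sequence, Tuple

Pmf = Dict[Tuple[int, ...], float]  # maps an outcome (x_1, ..., x_n) to its probability


def entropy(probs: Iterable[float]) -> float:
    """Shannon entropy in bits, skipping zero-probability outcomes."""
    return -sum(p * math.log2(p) for p in probs if p > 0)


def marginal(joint: Pmf, keep: Sequence[int]) -> Pmf:
    """Marginalize a joint pmf onto the coordinates listed in `keep`."""
    out: Pmf = defaultdict(float)
    for x, p in joint.items():
        out[tuple(x[i] for i in keep)] += p
    return out


def total_correlation(joint: Pmf, n: int) -> float:
    # TC(X) = sum_i H(X_i) - H(X)
    singles = sum(entropy(marginal(joint, [i]).values()) for i in range(n))
    return singles - entropy(joint.values())


def dual_total_correlation(joint: Pmf, n: int) -> float:
    # DTC(X) = H(X) - sum_i H(X_i | X_{-i}), using H(X_i | X_{-i}) = H(X) - H(X_{-i})
    h = entropy(joint.values())
    cond = sum(
        h - entropy(marginal(joint, [j for j in range(n) if j != i]).values())
        for i in range(n)
    )
    return h - cond


copy3 = {(0, 0, 0): 0.5, (1, 1, 1): 0.5}  # X1 = X2 = X3, one uniform bit
parity3 = {(a, b, a ^ b): 0.25 for a in (0, 1) for b in (0, 1)}  # X3 = X1 xor X2
print(total_correlation(copy3, 3), dual_total_correlation(copy3, 3))      # 2.0 1.0
print(total_correlation(parity3, 3), dual_total_correlation(parity3, 3))  # 1.0 2.0
```

The two toy distributions show that the quantities can differ sharply: three copies of one bit have high total correlation but low dual total correlation, while the parity example has the reverse profile.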
