Optimal Inference Schedules for Masked Diffusion Models

Sitan Chen, Kevin Cong, Jerry Li

TL;DR

This work analyzes parallel token sampling in Masked Diffusion Models by characterizing the optimal unmasking schedule in terms of the distribution’s information curve ${\bf Z}(\mu)$. It establishes that the minimal expected KL divergence between the target distribution and the sampler’s output equals the left-Riemann error between ${\bf Z}$ and its approximation ${\bf Z}^{\mathbf N}$, and shows how to compute the optimal schedule via dynamic programming. The authors prove strong impossibility results for universal schedulers, and then provide practical, information-theoretic upper bounds based on Total Correlation ${\rm TC}$ and Dual Total Correlation ${\rm DTC}$, yielding schedules with $k$ on the order of $O((1+\log n) \cdot \frac{1}{\varepsilon})$ under coarse estimates. These results illuminate when sublinear, near-optimal parallel sampling is possible and offer guidance for hyperparameter-tuned scheduling in real deployments. Overall, the paper connects diffusion-based sampling, information-theoretic measures, and univariate approximation to deliver rigorous limits and actionable strategies for efficient non-autoregressive language modeling.

Abstract

A major bottleneck of standard auto-regressive large language models is that their inference process is inherently sequential, resulting in very long and costly inference times. To circumvent this, practitioners proposed a class of language models called diffusion language models, of which the masked diffusion model (MDM) is the most successful. The MDM is able to sample tokens out-of-order and, ostensibly, many tokens at once and in parallel. However, there is very limited rigorous understanding of how much parallel sampling these models can perform without noticeable degradation in their sampling performance. Prior work of Li and Cai obtained some preliminary bounds, but these are not tight for many natural classes of distributions. In this work, we give a new, exact characterization of the expected divergence between the true distribution and the sampled distribution, for any distribution and any unmasking schedule for the sampler, showing an elegant connection to the theory of univariate function approximation. By leveraging this connection, we then attain a number of novel lower and upper bounds for this problem. While the connection to function approximation in principle gives the optimal unmasking schedule for any distribution, we show that it is in general impossible to compete with it without strong a priori knowledge of the distribution, even in seemingly benign settings. However, we also demonstrate new upper bounds and new sampling schedules in terms of well-studied information-theoretic properties of the base distribution, namely, its total correlation and dual total correlation, which show that in some natural settings, one can sample in $O(\log n)$ steps without any visible loss in performance, where $n$ is the total sequence length.

Paper Structure

This paper contains 20 sections, 20 theorems, 91 equations, 1 figure.

Key Result

Theorem 1.4

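Let $\mu$ be any distribution over $\Sigma^n$, and let $1 \leqslant k \leqslant n$. Let $\mathbf{N}^{*,k}$ be the solution to Eq. \eqref{eq:bestkstep} for $\mu$'s information curve ${\bf Z} = {\bf Z}(\mu)$. Then for any unmasking schedule $s_1,\ldots,s_k$, the expected KL error is given by the left Riemann approximation error $\|{\bf Z} - {\bf Z}^{\mathbf N}\|_{L^1}$ of the step approximation ${\bf Z}^{\mathbf N}$ associated with that schedule. In particular, the schedule that minimizes the expected KL error is the one associated with $\mathbf{N}^{*,k}$.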
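The dynamic program mentioned in the TL;DR can be illustrated with a minimal Python sketch. It rests on assumptions that are not spelled out in this summary: the information curve is taken to be a nondecreasing list `Z` holding $Z_1,\ldots,Z_n$, a step that unmasks positions $a+1,\ldots,b$ is assumed to approximate each of them by $Z_{a+1}$, and the error is the summed gap. The function name `best_k_step_schedule` and this cost model are illustrative and may differ in detail from Definition 1.2 and Eq. \eqref{eq:bestkstep}.

```python
from itertools import accumulate


def best_k_step_schedule(Z, k):
    """Return (error, schedule) for an assumed left Riemann cost model.

    Z is a nondecreasing list holding Z_1..Z_n (stored 0-indexed).
    schedule is a list of k positive step sizes summing to n.
    """
    n = len(Z)
    if not 1 <= k <= n:
        raise ValueError("need 1 <= k <= n")
    prefix = [0.0] + list(accumulate(Z))  # prefix[b] = Z_1 + ... + Z_b

    def cost(a, b):
        # Assumed error of one step unmasking positions a+1..b,
        # each approximated by Z_{a+1} (which is Z[a] in 0-indexing).
        return (prefix[b] - prefix[a]) - (b - a) * Z[a]

    inf = float("inf")
    # dp[j][b]: least error covering positions 1..b with exactly j steps.
    dp = [[inf] * (n + 1) for _ in range(k + 1)]
    parent = [[0] * (n + 1) for _ in range(k + 1)]
    dp[0][0] = 0.0
    for j in range(1, k + 1):
        for b in range(j, n + 1):
            for a in range(j - 1, b):
                cand = dp[j - 1][a] + cost(a, b)
                if cand < dp[j][b]:
                    dp[j][b] = cand
                    parent[j][b] = a

    # Walk the parent pointers back from (k, n) to recover step sizes.
    schedule, b = [], n
    for j in range(k, 0, -1):
        a = parent[j][b]
        schedule.append(b - a)
        b = a
    schedule.reverse()
    return dp[k][n], schedule
```

For a curve with a single jump, such as `best_k_step_schedule([0.0, 0.0, 1.0, 1.0], 2)`, the sketch places the breakpoint at the jump. The triple loop costs $O(kn^2)$ time, which is adequate for illustration but is not claimed to match the authors' procedure.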

Figures (1)

  • Figure 1: Discrete curve $\mathbf Z$ (blue) and left Riemann approximation $\mathbf Z^{\mathbf N}$ (red) for a sample $Z_i$ curve. The latter extends beyond the $Z_j$ curve to $n+1$ to show the final rectangle $Z_n - Z_{N_{k-1} + 1}$; note that this term is not present in a standard left Riemann approximation. Light blue background rectangles represent the Riemann approximation terms. The total area is $\|{\bf Z} - {\bf Z}^{\mathbf N}\|_{L^1}$.

Theorems & Definitions (52)

  • Definition 1.1: Expected KL error
  • Definition 1.2: Left Riemann approximation
  • Definition 1.3: Average mutual information curve
  • Theorem 1.4: Optimal schedule given by best step approximation
  • Theorem 1.5: Uniform versus code is hard, see Theorem \ref{thm:warmup} for formal statement
  • Theorem 1.6: Hardness for general information curves, see Theorem \ref{thm:elevate} for formal statement
  • Definition 1.7: Total Correlation and Dual Total Correlation
  • Lemma 1.8
  • Theorem 1.9: Iteration complexity depending on $\mathrm{TC}, \mathrm{DTC}$
  • Theorem 1.10: Austin's iteration complexity bound \cite{austin2020multi}
  • ...and 42 more
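
As a companion to Definition 1.7, the following Python sketch computes total correlation and dual total correlation for a small joint distribution given explicitly as a dictionary. It uses the standard definitions $\mathrm{TC} = \sum_i H(X_i) - H(X)$ and $\mathrm{DTC} = H(X) - \sum_i H(X_i \mid X_{-i})$, measured in nats; the paper's own conventions (for example the logarithm base) may differ, and the helper names `entropy`, `marginal`, and `tc_and_dtc` are hypothetical.

```python
from math import log


def entropy(p):
    """Shannon entropy in nats of a dict mapping outcomes to probabilities."""
    return -sum(q * log(q) for q in p.values() if q > 0)


def marginal(joint, keep):
    """Marginal of `joint` (dict: tuple -> prob) on the coordinates in `keep`."""
    out = {}
    for x, q in joint.items():
        key = tuple(x[i] for i in keep)
        out[key] = out.get(key, 0.0) + q
    return out


def tc_and_dtc(joint):
    """Return (TC, DTC) of a joint distribution over equal-length tuples."""
    n = len(next(iter(joint)))
    h_joint = entropy(joint)
    # TC = sum_i H(X_i) - H(X)
    tc = sum(entropy(marginal(joint, [i])) for i in range(n)) - h_joint
    # DTC = H(X) - sum_i H(X_i | X_{-i}), using H(X_i | X_{-i}) = H(X) - H(X_{-i})
    cond = 0.0
    for i in range(n):
        rest = [j for j in range(n) if j != i]
        cond += h_joint - entropy(marginal(joint, rest))
    dtc = h_joint - cond
    return tc, dtc


# Three copies of one fair bit: strongly redundant coordinates.
copies = {(0, 0, 0): 0.5, (1, 1, 1): 0.5}
# Uniform distribution over 3-bit strings with even parity.
parity = {(a, b, a ^ b): 0.25 for a in (0, 1) for b in (0, 1)}
```

On these two toy inputs the quantities behave differently: the repeated bit has $\mathrm{TC} = 2\log 2$ and $\mathrm{DTC} = \log 2$, while the parity distribution has $\mathrm{TC} = \log 2$ and $\mathrm{DTC} = 2\log 2$. The enumeration is exponential in $n$ and is meant only for small examples.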