Enabling Approximate Joint Sampling in Diffusion LMs
Parikshit Bansal, Sujay Sanghavi
TL;DR
This work tackles the mismatch between joint distributions and parallel sampling in masked diffusion language models by introducing ADJUST, a lightweight single-layer transformer that sits atop a frozen diffusion LM to realize approximate joint sampling. By performing one base forward pass followed by $K-1$ passes of the ADJUST layer, the method unmasks $K$ tokens per step in a way that conditions each token on the identities of the tokens already unmasked in that step. Training minimizes the KL divergence between the true joint distribution $p_*$ and the ADJUST approximation $p_{\text{ADJUST}}$, using trajectories from the base model's denoising process. Empirically, ADJUST yields substantial gains in MAUVE and GSM8K across pretrained and instruction-tuned models: with four tokens unmasked per step it achieves quality close to that of exact joint sampling while maintaining competitive throughput, and it outperforms naive parallel sampling on downstream tasks and on the ParallelBench benchmark. The approach offers a practical path to faster diffusion-based generation without external autoregressive verification, potentially enabling more efficient deployment of diffusion LMs in real-world applications.
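To make the three distributions in this summary concrete, the following is one plausible way to write them down; it is a sketch, not the paper's exact notation. The symbols $p_\theta$ (the frozen base model), $q_\phi$ (the ADJUST layer), $h$ (the base model's hidden states), the positions $i_1, \dots, i_K$ unmasked in one step, and the direction of the KL divergence are all assumptions made here for illustration. Let $x$ denote the partially masked sequence at the start of the step.

$$
\begin{aligned}
p_*(x_{i_1}, \dots, x_{i_K} \mid x) &= \prod_{k=1}^{K} p_\theta\big(x_{i_k} \mid x,\, x_{i_1}, \dots, x_{i_{k-1}}\big) && \text{(exact joint: $K$ full forward passes)} \\
p_{\text{marginal}}(x_{i_1}, \dots, x_{i_K} \mid x) &= \prod_{k=1}^{K} p_\theta\big(x_{i_k} \mid x\big) && \text{(naive parallel: 1 full pass)} \\
p_{\text{ADJUST}}(x_{i_1}, \dots, x_{i_K} \mid x) &= p_\theta\big(x_{i_1} \mid x\big) \prod_{k=2}^{K} q_\phi\big(x_{i_k} \mid h,\, x_{i_1}, \dots, x_{i_{k-1}}\big) && \text{(1 full pass, $K-1$ ADJUST passes)} \\
\min_{\phi}\ & \mathbb{E}_{x}\Big[\mathrm{KL}\big(p_*(\cdot \mid x)\ \big\|\ p_{\text{ADJUST}}(\cdot \mid x)\big)\Big]
\end{aligned}
$$

Under this reading, the marginal baseline drops every dependence among the $K$ new tokens, whereas $p_{\text{ADJUST}}$ keeps the chain-rule structure of $p_*$ and only replaces the expensive conditionals with the single-layer approximation.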
Abstract
In autoregressive language models, each token is sampled by conditioning on all the past tokens; the overall string has thus been sampled from the correct underlying joint distribution represented by the model. In contrast, masked diffusion language models generate text by unmasking tokens out of order and potentially in parallel. Sampling a string from the correct underlying joint distribution would (again) require unmasking exactly one token in every full-model forward pass. The more tokens are unmasked in parallel, the further the string is from the true joint; this shows up as a drop in accuracy, in exchange for an increase in speed. In this paper we devise a way to {\em approximately} sample multiple tokens from the joint distribution in a single full-model forward pass; we do so by developing a new lightweight single-layer ``sampler'' on top of an existing large diffusion LM. One forward pass of the full model can now be followed by multiple forward passes of only this sampler layer, to yield multiple unmasked tokens. Our sampler is trained to mimic exact joint sampling from the (frozen) full model. We show the effectiveness of our approximate joint sampling for both pretrained-only (Dream-7B-Base, Llada-7B-Base) and instruction-tuned (Dream-7B-Instruct, Dream-7B-Coder) models on language modeling and math \& coding tasks. When four tokens are unmasked for each full-model denoising step, our sampling algorithm achieves a MAUVE score of 0.87 (vs. 0.31 for the marginal baseline) with respect to the true joint distribution.
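The control flow described above (one full-model pass, then several passes of only the sampler layer) can be sketched in a few lines of NumPy. This is a toy illustration under stated assumptions, not the authors' implementation: `make_toy_models`, `base_forward`, `adjust_forward`, the `MASK` id, and the "most confident position first" unmasking rule are all hypothetical stand-ins, and the toy "models" are random linear maps rather than transformers.

```python
import numpy as np

MASK = -1  # hypothetical id marking a masked position


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def make_toy_models(vocab, dim, seed=0):
    """Random stand-ins for the frozen LM and the light sampler layer."""
    rng = np.random.default_rng(seed)
    emb = rng.normal(size=(vocab + 1, dim))  # row `vocab` embeds MASK
    w_out = rng.normal(size=(dim, vocab))
    calls = {"base": 0, "adjust": 0}         # counts passes of each kind

    def base_forward(tokens):
        # "Expensive" pass: returns hidden states and per-position logits.
        calls["base"] += 1
        ids = np.where(tokens == MASK, vocab, tokens)
        hidden = emb[ids]
        return hidden, hidden @ w_out

    def adjust_forward(hidden, tokens):
        # "Cheap" pass: reuses the base hidden states and conditions on
        # the identities of the tokens committed so far.
        calls["adjust"] += 1
        committed = tokens[tokens != MASK]
        context = emb[committed].mean(axis=0)
        return (hidden + context) @ w_out

    return base_forward, adjust_forward, calls


def unmask_k_tokens(tokens, base_forward, adjust_forward, k, rng):
    """One denoising step: 1 base pass, then up to k-1 sampler passes."""
    tokens = tokens.copy()
    hidden, logits = base_forward(tokens)
    for step in range(k):
        masked = np.flatnonzero(tokens == MASK)
        if masked.size == 0:
            break
        if step > 0:
            logits = adjust_forward(hidden, tokens)
        probs = softmax(logits[masked])
        # Assumed rule: unmask the most confident masked position next.
        best = probs.max(axis=-1).argmax()
        tokens[masked[best]] = rng.choice(probs.shape[-1], p=probs[best])
    return tokens


rng = np.random.default_rng(0)
base_forward, adjust_forward, calls = make_toy_models(vocab=5, dim=4)
tokens = np.full(8, MASK)
out = unmask_k_tokens(tokens, base_forward, adjust_forward, k=4, rng=rng)
print(out)
print(calls)  # {'base': 1, 'adjust': 3}
```

With `k=4`, the step commits four tokens while calling the base model once and the sampler layer three times. Setting `k=1` reduces this loop to exact one-token-per-pass sampling, and skipping the `adjust_forward` refresh would recover the marginal (naive parallel) baseline.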
