Enabling Approximate Joint Sampling in Diffusion LMs

Parikshit Bansal, Sujay Sanghavi

TL;DR

This work tackles the mismatch between the joint distribution a masked diffusion language model represents and the distribution that parallel sampling actually draws from. It introduces ADJUST, a lightweight single-layer transformer that sits atop a frozen diffusion LM to realize approximate joint sampling. By performing one base forward pass followed by $K-1$ passes of the ADJUST layer, the method unmasks $K$ tokens per step while conditioning each token on the identities of the tokens unmasked before it. Training minimizes the KL divergence between the true joint distribution $p_*$ and the ADJUST-driven approximation $p_{\text{ADJUST}}$, using trajectories from the base model's denoising process. Empirically, ADJUST yields substantial gains in MAUVE and GSM8K accuracy across pretrained and instruction-tuned models: with four tokens per step it approaches joint-distribution quality while maintaining competitive throughput, and it outperforms naive parallel sampling on downstream tasks and on the ParallelBench benchmark. The approach offers a practical path to faster diffusion-based generation without external autoregressive verification, potentially enabling more efficient deployment of diffusion LMs in real-world applications.
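One way to make the training objective concrete is sketched below. This is a reconstruction from the summary above, not the paper's stated equations, and the notation is our own assumption: $x_t$ is the partially masked sequence at a denoising step, $i_1,\dots,i_K$ are the positions unmasked in that step in a fixed order, $p_f$ is the frozen base model's conditional, $h(x_t)$ its hidden states, $g_\theta$ the ADJUST layer, and the KL is taken in the forward direction $\mathrm{KL}(p_*\,\|\,p_{\text{ADJUST}})$.

```latex
\begin{align*}
p_*(x_{i_1},\dots,x_{i_K}\mid x_t)
  &= \prod_{k=1}^{K} p_f\big(x_{i_k}\mid x_t,\, x_{i_1},\dots,x_{i_{k-1}}\big)
  && \text{($K$ full forward passes)}\\
p_{\text{ADJUST}}(x_{i_1},\dots,x_{i_K}\mid x_t)
  &= p_f(x_{i_1}\mid x_t)\prod_{k=2}^{K} g_\theta\big(x_{i_k}\mid h(x_t),\, x_{i_1},\dots,x_{i_{k-1}}\big)
  && \text{($1$ full pass $+$ $K-1$ passes of $g_\theta$)}\\
\mathrm{KL}(p_*\,\|\,p_{\text{ADJUST}})
  &= \mathbb{E}_{p_*}\!\left[\log p_*\right] - \mathbb{E}_{p_*}\!\left[\log p_{\text{ADJUST}}\right]\\
\theta^\star
  &= \arg\min_\theta\ \mathbb{E}_{x\sim p_*}\Big[-\sum_{k=2}^{K}\log g_\theta\big(x_{i_k}\mid h(x_t),\, x_{i_1},\dots,x_{i_{k-1}}\big)\Big]
\end{align*}
```

The last line follows because $p_f$ is frozen: neither $\mathbb{E}_{p_*}[\log p_*]$ nor the $k=1$ factor of $p_{\text{ADJUST}}$ depends on $\theta$, so under these assumptions minimizing the KL reduces to a cross-entropy loss on sequences drawn from the base model's one-token-at-a-time denoising trajectories.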

Abstract

In autoregressive language models, each token is sampled by conditioning on all the past tokens, so the overall string is sampled from the correct underlying joint distribution represented by the model. In contrast, masked diffusion language models generate text by unmasking tokens out of order and potentially in parallel. Generating a string from the correct underlying joint distribution would (again) require unmasking exactly one token per full-model forward pass. The more tokens unmasked in parallel, the further the string is from the true joint; this shows up as a drop in accuracy (in exchange for an increase in speed). In this paper we devise a way to approximately sample multiple tokens from the joint distribution in a single full-model forward pass; we do so by developing a new lightweight single-layer "sampler" on top of an existing large diffusion LM. One forward pass of the full model can now be followed by multiple forward passes of only this sampler layer to yield multiple unmasked tokens. Our sampler is trained to mimic exact joint sampling from the (frozen) full model. We show the effectiveness of our approximate joint sampling for both pretrained-only (Dream-7B-Base, Llada-7B-Base) and instruction-tuned (Dream-7B-Instruct, Dream-7B-Coder) models on language modeling and math & coding tasks. When four tokens are unmasked for each full-model denoising step, our sampling algorithm achieves a MAUVE score of 0.87 (vs. 0.31 for the marginal baseline) with respect to the true joint distribution.
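A small numerical illustration of "further from the true joint", under assumptions of our own: when several positions are unmasked in parallel from one forward pass, each token is drawn from its own marginal, so the string comes from the product of the marginals rather than from the joint. The Python sketch below computes $\mathrm{KL}(\text{joint}\,\|\,\text{product of marginals})$, the total correlation of the unmasked tokens, which is zero when only one token is unmasked. The toy joint and the helper names (`marginals`, `parallel_distribution`, `kl`) are hypothetical.

```python
import math
from itertools import product
from typing import Dict, List, Tuple

Joint = Dict[Tuple[str, ...], float]  # maps a K-token tuple to its probability


def marginals(joint: Joint) -> List[Dict[str, float]]:
    """Per-position marginal distributions of a joint over K token positions."""
    k = len(next(iter(joint)))
    margs: List[Dict[str, float]] = [{} for _ in range(k)]
    for tokens, p in joint.items():
        for i, tok in enumerate(tokens):
            margs[i][tok] = margs[i].get(tok, 0.0) + p
    return margs


def parallel_distribution(joint: Joint) -> Joint:
    """What naive parallel sampling draws from: the product of the marginals."""
    out: Joint = {}
    for combo in product(*(m.items() for m in marginals(joint))):
        out[tuple(tok for tok, _ in combo)] = math.prod(p for _, p in combo)
    return out


def kl(p: Joint, q: Joint) -> float:
    """KL(p || q) in nats; assumes q > 0 wherever p > 0."""
    return sum(pv * math.log(pv / q[x]) for x, pv in p.items() if pv > 0)


# Hypothetical joint over two masked positions with only two coherent completions.
joint = {("New", "York"): 0.5, ("Los", "Angeles"): 0.5}
q = parallel_distribution(joint)
print(kl(joint, q))                                # log(2), about 0.693 nats
print(q[("New", "Angeles")] + q[("Los", "York")])  # 0.5 of the mass is incoherent
```

Here half of the parallel-sampling mass lands on incoherent pairs such as ("New", "Angeles"), which is the kind of error a joint sampler is meant to avoid.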

Paper Structure

This paper contains 38 sections, 5 equations, 5 figures, 6 tables, 2 algorithms.

Figures (5)

  • Figure 1: Using Dream-7B-Base, we generate token strings of length 128 starting from a fully masked string (see the language-modeling subsection for details). We report the negative log-likelihood (NLL) and the MAUVE score of the generated strings while varying the number of tokens sampled per forward pass of the diffusion model (denoising step) from one to four. NLL increases and MAUVE decreases as more tokens are generated in parallel; we argue this is because parallel generation samples from a distribution different from the true joint distribution. For a given number of tokens generated per step, our joint sampler ADJUST achieves the best NLL and MAUVE (compared with baselines such as naive parallel sampling and energy-based models).
  • Figure 2: In this figure we report runtime (tokens produced per second) together with the corresponding MAUVE and GSM8K scores for Dream-7B-Base and Dream-7B-Instruct, respectively. We vary the number of tokens sampled per diffusion denoising step (from one to four, each corresponding to a point on the curve) to vary token throughput. As expected, sampling more tokens per step decreases both MAUVE and GSM8K performance. Our joint sampler reduces throughput only slightly compared with naive parallel sampling, and, importantly, for a given target throughput it outperforms naive parallel sampling. For example, when decoding four tokens per diffusion pass, our joint approximation is only 20-25% slower than naive parallel decoding while being 16 percentage points more accurate on GSM8K and 0.5 higher on MAUVE.
  • Figure 3: This figure illustrates our method ADJUST. The left sub-figure shows naive parallel sampling; the right sub-figure shows ADJUST. In this example, each denoising step generates three tokens. For illustration, assume that the diffusion model $f$ has support on only two distinct sentences, "The cat sat on the mat" and "The dog ran in the yard". Sampling tokens in parallel (left) generates a sentence that is incoherent under the diffusion model's underlying distribution. In contrast, ADJUST (right) conditions each sampled token on the previously sampled tokens through a lightweight network, shown in the figure as $g$. Note that we use the same number of forward passes of the diffusion model $f$ as the naive parallel sampler (one for every three generated tokens).
  • Figure 4: In this figure we compare adaptive parallel decoding (APD, which works left-to-right) with naive parallel decoding and ADJUST. With left-to-right decoding, sampling $K>1$ tokens per diffusion step increases NLL (left) more steeply than with the base model's own unmasking logic; hence APD is ill-suited to such open-ended tasks and under-performs. The right panel shows GSM8K accuracy under high-temperature sampling (temperature 1.0) for the Dream-7B-Instruct model. Left-to-right generation is naturally suited to reasoning tasks, as shown by APD's higher accuracy than naive/joint sampling at $K=1$ token per step. However, as $K$ increases, APD's accuracy drops faster than that of the baseline and our method.
  • Figure 5: This figure shows the architecture of our draft model.
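The two-sentence example of Figure 3 can be reproduced numerically. The Python sketch below is a toy under stated assumptions: the two sentences are taken to be equally likely, and the exact conditional of this tiny distribution plays the role of the trained sampler $g$ (an idealized oracle; the real ADJUST layer is a learned single transformer layer reading the base model's hidden states, which this does not model). All names (`SUPPORT`, `conditional`, `decode_step`) are hypothetical. In a real system the naive mode costs one pass of $f$ per step and the joint mode adds $K-1$ cheap passes of $g$; here both are emulated by the same function.

```python
import random
from typing import Dict, List, Optional, Sequence

# Toy stand-in for the frozen diffusion model f (assumption, mirrors Figure 3):
# all probability mass sits on two sentences, assumed equally likely.
SUPPORT: Dict[tuple, float] = {
    ("The", "cat", "sat", "on", "the", "mat"): 0.5,
    ("The", "dog", "ran", "in", "the", "yard"): 0.5,
}
Seq = List[Optional[str]]  # None marks a masked position


def consistent(seq: Sequence[Optional[str]]) -> List[tuple]:
    """Supported sentences that agree with every unmasked token in seq."""
    return [s for s in SUPPORT if all(t is None or t == w for t, w in zip(seq, s))]


def conditional(seq: Sequence[Optional[str]], pos: int) -> Dict[str, float]:
    """Exact p(x_pos | unmasked tokens of seq) under the toy distribution."""
    matches = consistent(seq)
    if not matches:
        raise ValueError("sequence has zero probability under the toy model")
    total = sum(SUPPORT[s] for s in matches)
    dist: Dict[str, float] = {}
    for s in matches:
        dist[s[pos]] = dist.get(s[pos], 0.0) + SUPPORT[s] / total
    return dist


def draw(dist: Dict[str, float], rng: random.Random) -> str:
    """Sample one token from a {token: probability} dict."""
    return rng.choices(list(dist), weights=list(dist.values()), k=1)[0]


def decode_step(seq: Seq, positions: Sequence[int], joint: bool, rng: random.Random) -> Seq:
    """Unmask every index in `positions` in one denoising step; seq is not mutated."""
    out = list(seq)
    if not joint:
        # Naive parallel: all positions use marginals from the same (pre-step) pass.
        dists = [conditional(seq, p) for p in positions]
        for p, d in zip(positions, dists):
            out[p] = draw(d, rng)
        return out
    # Joint (ADJUST-style): the first token uses f's marginal; each later token
    # conditions on tokens already drawn in this step (oracle standing in for g).
    for p in positions:
        out[p] = draw(conditional(out, p), rng)
    return out


rng = random.Random(0)
start: Seq = ["The", None, None, None, "the", None]
print(decode_step(start, [1, 2, 3], joint=False, rng=rng))  # may mix the two sentences
print(decode_step(start, [1, 2, 3], joint=True, rng=rng))   # always one sentence's tokens
```

With three independently sampled positions, naive parallel sampling stays coherent only when all three draws pick the same sentence (probability 2 × 1/8 = 1/4), so about three samples in four mix the sentences, while the joint mode never does.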