KLASS: KL-Guided Fast Inference in Masked Diffusion Models

Seo Hyun Kim, Sunwoo Hong, Hojung Jung, Youngrok Park, Se-Young Yun

TL;DR

KLASS is a training-free, adaptive sampling method for masked diffusion models that uses token-level confidence and KL divergence to identify stable tokens for parallel unmasking. By dynamically selecting tokens with high confidence and low KL divergence between consecutive steps, KLASS achieves speedups of up to 2.78x while improving performance on reasoning tasks, and it generalizes to text, image, and molecular generation. The mechanism is simple: compute a confidence score and a KL score for each token, maintain a history of KL scores, and unmask the tokens that remain stable, without extra training or heavy overhead. Empirically, KLASS delivers gains across these domains, making it a practical option for faster and more reliable diffusion-based generation.

Abstract

Masked diffusion models have demonstrated competitive results on various tasks, including language generation. However, because of their iterative refinement process, inference is often bottlenecked by slow and static sampling. To overcome this problem, we introduce "KL-Adaptive Stability Sampling" (KLASS), a fast yet effective sampling method that exploits token-level KL divergence to identify stable, high-confidence predictions. By unmasking multiple tokens in each iteration without any additional model training, our approach speeds up generation significantly while maintaining sample quality. On reasoning benchmarks, KLASS achieves up to $2.78\times$ wall-clock speedups while improving performance over standard greedy decoding, attaining state-of-the-art results among diffusion-based samplers. We further validate KLASS across diverse domains, including text, image, and molecular generation, showing its effectiveness as a broadly applicable sampler across different models.
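
The selection rule described above can be illustrated with a minimal sketch. It is not the authors' implementation: the function names, the threshold values (`conf_threshold`, `kl_threshold`), the direction of the KL divergence, and the toy distributions are all assumptions chosen for illustration. It only shows the idea of unmasking positions that are both confident and stable across consecutive steps.

```python
import math


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(
        pi * math.log(max(pi, eps) / max(qi, eps))
        for pi, qi in zip(p, q)
        if pi > 0
    )


def select_stable_positions(history, masked, conf_threshold=0.9, kl_threshold=0.01):
    """Return masked positions that are both confident and stable.

    history: list of per-step predictions, oldest first; each entry is a list
             of per-position probability distributions over the vocabulary.
    masked:  list of booleans, True where the position is still masked.
    The thresholds are hypothetical values, not taken from the paper.
    """
    if len(history) < 2:
        raise ValueError("need at least two steps to measure stability")
    current = history[-1]
    selected = []
    for pos, is_masked in enumerate(masked):
        if not is_masked:
            continue
        # Confidence: probability of the most likely token at the current step.
        confidence = max(current[pos])
        # KL score between each pair of consecutive steps in the history window.
        kl_scores = [
            kl_divergence(history[t + 1][pos], history[t][pos])
            for t in range(len(history) - 1)
        ]
        if confidence > conf_threshold and all(k < kl_threshold for k in kl_scores):
            selected.append(pos)
    return selected


# Toy example: 3 positions, vocabulary of 3 tokens, two consecutive steps.
previous = [[0.95, 0.03, 0.02], [0.40, 0.30, 0.30], [0.20, 0.70, 0.10]]
current = [[0.96, 0.02, 0.02], [0.92, 0.04, 0.04], [0.96, 0.02, 0.02]]
masked = [True, True, False]

selected = select_stable_positions([previous, current], masked)
print(selected)  # [0]
```

In this toy input, position 1 is confident at the current step but its distribution changed sharply since the previous step, so it stays masked; position 2 is already unmasked and is skipped. Only position 0 satisfies both criteria.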

Paper Structure

This paper contains 68 sections, 2 theorems, 27 equations, 5 figures, 19 tables, 1 algorithm.

Key Result

Proposition 5.3

Suppose $p_\theta$ is a conditional $\delta$‑approximation of $\pi$. For any context path $c_M\!\to c_{M-1}\!\to\cdots\!\to c_0$ (changing only variables outside $X_i$) ending at $c_0=c^\star$, let $P_t:=p_\theta(\,\cdot\,\mid c_t)$ and $\Delta:=\tfrac{1}{2}(\beta+\gamma-2\delta)_+$. Then

Figures (5)

  • Figure 1: KL divergence as a strong indicator of solution correctness. (a) The Top-$k$ method selects an incorrect solution despite high confidence, whereas KLASS identifies the correct solution, which exhibits a significantly lower KL divergence. (b) KL divergence distributions for the LLaDA and DREAM models show that correct predictions consistently have lower KL divergence than incorrect ones across all datasets.
  • Figure 2: Illustration of parallel decoding with KLASS. Tokens are unmasked when they meet two criteria: high predictive confidence and a stable probability distribution. Stability is measured by a low KL divergence between consecutive steps (illustrated with a history length of 1 for simplicity). The right panel shows the sampling process for position 245: it remains masked while its confidence is low or its KL score is high, and is unmasked once both conditions are satisfied.
  • Figure 3: KL Effect Across Confidence Levels on MATH.
  • Figure 4: MDLM Generated Sample (512 tokens)
  • Figure 5: KLASS Generated Sample (512 tokens)
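
As a reading aid for Figure 2, the two unmasking criteria can be written out for a history length of 1. The notation here is assumed rather than taken from the paper's definitions: $p_t^i$ is the model's predicted distribution for position $i$ at sampling step $t$, $p_{t-1}^i$ is the prediction at the previous step, $\mathcal{V}$ is the vocabulary, and $\tau$ and $\epsilon$ are hypothetical confidence and KL thresholds. The direction of the KL divergence is also an assumption.

```latex
\begin{align*}
\mathrm{conf}_t^i &= \max_{v \in \mathcal{V}} p_t^i(v), \\
d_t^i &= D_{\mathrm{KL}}\!\left(p_t^i \,\|\, p_{t-1}^i\right)
       = \sum_{v \in \mathcal{V}} p_t^i(v) \log \frac{p_t^i(v)}{p_{t-1}^i(v)}, \\
\text{unmask position } i \text{ at step } t
  &\iff \mathrm{conf}_t^i > \tau \ \text{ and } \ d_t^i < \epsilon .
\end{align*}
```

With a longer history, the stability condition would be required of every KL score in the window rather than only the most recent one.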

Theorems & Definitions (10)

  • Definition 4.1
  • Definition 4.2
  • Definition 5.1
  • Definition 5.2
  • Proposition 5.3
  • proof
  • Definition A.1
  • Definition A.2
  • Proposition A.3
  • proof