KLASS: KL-Guided Fast Inference in Masked Diffusion Models
Seo Hyun Kim, Sunwoo Hong, Hojung Jung, Youngrok Park, Se-Young Yun
TL;DR
KLASS is a training-free, adaptive sampling method for masked diffusion models that uses token-level confidence and KL divergence to identify stable tokens for parallel unmasking. By dynamically selecting tokens with high confidence and low KL drift, KLASS achieves wall-clock speedups of up to 2.78x while improving performance on reasoning tasks, and it generalizes to text, image, and molecular generation. The mechanism is simple: compute a confidence score and a KL score for each token, maintain a history of KL scores, and unmask the tokens that remain stable. It requires no extra training and adds little overhead. Empirically, KLASS delivers consistent gains across these domains, making it a practical sampler for faster and more reliable diffusion-based generation.
Abstract
Masked diffusion models have demonstrated competitive results on various tasks, including language generation. However, because of their iterative refinement process, inference is often bottlenecked by slow and static sampling. To overcome this problem, we introduce "KL-Adaptive Stability Sampling" (KLASS), a fast yet effective sampling method that exploits token-level KL divergence to identify stable, high-confidence predictions. By unmasking multiple tokens in each iteration without any additional model training, our approach speeds up generation significantly while maintaining sample quality. On reasoning benchmarks, KLASS achieves up to $2.78\times$ wall-clock speedups while improving performance over standard greedy decoding, attaining state-of-the-art results among diffusion-based samplers. We further validate KLASS across diverse domains, including text, image, and molecular generation, showing its effectiveness as a broadly applicable sampler across different models.
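The selection rule described above (a confidence score, a KL score, a KL history, and unmasking of stable tokens) can be illustrated with a minimal sketch. This is not the authors' implementation. It assumes that the KL score compares a position's current predictive distribution with its distribution at the previous step, and that a token counts as stable when its last few KL scores all fall below a threshold. The function names, the threshold values, the history length, and the fallback that unmasks one token when none qualifies are hypothetical choices made for illustration.

```python
import numpy as np


def kl_divergence(p, q, eps=1e-12):
    """Row-wise KL(p || q) for arrays of shape (positions, vocab)."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return np.sum(p * (np.log(p) - np.log(q)), axis=-1)


def select_stable_tokens(probs, prev_probs, kl_history, masked,
                         conf_threshold=0.9, kl_threshold=0.01,
                         history_len=2):
    """Pick masked positions to unmask in parallel at this step.

    probs, prev_probs: (positions, vocab) distributions at the current and
        previous step (the direction of the KL is an assumption).
    kl_history: (positions, k) KL scores from earlier steps, k may be 0.
    masked: (positions,) boolean array, True where the token is still masked.
    All thresholds and history_len are illustrative values, not the paper's.
    """
    # Confidence score: probability of the most likely token.
    conf = probs.max(axis=-1)

    # KL score: how much the prediction moved since the previous step.
    kl = kl_divergence(probs, prev_probs)

    # Append the new score and keep only the most recent history_len scores.
    kl_history = np.concatenate([kl_history, kl[:, None]], axis=1)
    kl_history = kl_history[:, -history_len:]

    # A position is stable once it has a full history of small KL scores.
    if kl_history.shape[1] < history_len:
        stable = np.zeros(len(conf), dtype=bool)
    else:
        stable = np.all(kl_history < kl_threshold, axis=1)

    selected = masked & stable & (conf > conf_threshold)

    # Fallback (assumption): guarantee progress by unmasking the single
    # most confident masked position when nothing passes both tests.
    if masked.any() and not selected.any():
        best = int(np.argmax(np.where(masked, conf, -np.inf)))
        selected[best] = True

    return selected, kl_history
```

In a sampling loop, a caller would run the model, pass the resulting distributions together with those from the previous step, commit the selected positions, and carry the returned history forward. Because the number of selected positions varies from step to step, the number of model calls adapts to how quickly predictions settle.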
