
Latent Refinement Decoding: Enhancing Diffusion-Based Language Models by Refining Belief States

Qinglin Zhu, Yizhen Yao, Runcong Zhao, Yanzheng Xiang, Amrutha Saseendran, Chen Jin, Philip Teare, Bin Liang, Yulan He, Lin Gui

TL;DR

Latent Refinement Decoding (LRD) tackles high inference latency and information loss in diffusion-based language models with a two-stage decoding pipeline. Phase 1 performs distribution-preserving latent refinement in embedding space through soft embeddings and entropy-guided mixing; Phase 2 uses a Predictive Feedback Loop to progressively finalize tokens under a KL-divergence-based stopping criterion. The method yields consistent accuracy improvements and speedups of up to 10.6× across coding and reasoning benchmarks, and remains robust across context lengths and model families. By preserving distributional information and providing principled convergence checks, LRD serves as a drop-in decoding mechanism for diffusion LMs that can be combined with system-level accelerations such as KV caching and speculative decoding.
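The soft-embedding idea behind Phase 1 can be sketched as follows. This is a minimal illustration under our own assumptions, not the paper's formulation: the helpers `normalized_entropy`, `top_p_filter`, and `soft_mask_embedding` are hypothetical, the token share is assumed to be `alpha = r_f * (1 - H / log V)` (so low-entropy positions move further from the [MASK] embedding, and `r_f = 0` leaves it untouched, matching the no-mixing case in Figure 4), and the mixture is restricted to a top-p nucleus in the spirit of Figure 5.

```python
import math
from typing import Dict, List, Sequence

def normalized_entropy(probs: Sequence[float]) -> float:
    """Shannon entropy divided by log(V), so the result lies in [0, 1]."""
    if len(probs) < 2:
        return 0.0
    h = -sum(p * math.log(p) for p in probs if p > 0.0)
    return h / math.log(len(probs))

def top_p_filter(probs: Sequence[float], top_p: float) -> Dict[int, float]:
    """Smallest set of most likely token ids whose mass reaches top_p, renormalized."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    mass = 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    return {i: probs[i] / mass for i in kept}

def soft_mask_embedding(probs: Sequence[float],
                        token_emb: Sequence[Sequence[float]],
                        mask_emb: Sequence[float],
                        r_f: float = 0.5,
                        top_p: float = 0.9) -> List[float]:
    """Blend the [MASK] embedding with the expected embedding of likely tokens.

    Hypothetical mixing rule: alpha = r_f * (1 - normalized entropy), so
    confident positions move further from [MASK]; r_f = 0 means no mixing.
    """
    alpha = r_f * (1.0 - normalized_entropy(probs))
    weights = top_p_filter(probs, top_p)
    dim = len(mask_emb)
    # Probability-weighted average of the embeddings of the kept tokens.
    expected = [sum(w * token_emb[i][d] for i, w in weights.items()) for d in range(dim)]
    return [(1.0 - alpha) * mask_emb[d] + alpha * expected[d] for d in range(dim)]
```

For example, a one-hot prediction with `r_f=0.5` lands halfway between the mask embedding and that token's embedding, while a uniform prediction (maximal entropy) leaves the [MASK] embedding essentially unchanged. Tying the mixing weight to entropy is one simple way to keep uncertain positions close to the neutral mask state.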

Abstract

Autoregressive (AR) models remain the standard for natural language generation but incur high latency due to strictly sequential decoding. Recent diffusion-inspired approaches, such as LLaDA and Dream, mitigate this by generating in parallel, yet they suffer from two core limitations: information loss, as predictive distributions for non-finalized tokens are discarded at each step, and premature commitment, where local decisions are made without sufficient global coordination. We introduce Latent Refinement Decoding (LRD), a two-stage framework consisting of Latent Refinement and a Predictive Feedback Loop. The first stage maintains masked positions as distributional mixtures of predicted tokens and the mask embedding, allowing the model to establish more globally consistent beliefs. The second stage progressively finalizes confident tokens while retaining uncertain ones for iterative feedback. KL-divergence dynamics provide a principled and reliable criterion for convergence and early stopping. Experiments across coding (HumanEval +6.3, MBPP +2.6) and reasoning (GSM8K +2.9, MATH500 +3.8) show that LRD improves accuracy while delivering speedups of up to 10.6×, making it a strong and versatile alternative for parallel sequence generation.
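The second-stage behaviour described in the abstract (commit confident tokens, keep refining the rest, stop when KL dynamics flatten) can be pictured with a small sketch. It is a simplification under stated assumptions: `predict` is a hypothetical callable standing in for the diffusion LM (it would be responsible for feeding soft embeddings for uncommitted positions), the commit rule is a plain top-probability threshold `commit_prob`, and early stopping compares the mean KL divergence between consecutive predictions with a hypothetical tolerance `kl_tol`. The paper's actual thresholds and schedule are not reproduced here.

```python
import math
from typing import Callable, List, Optional, Sequence

Dist = Sequence[float]

def kl_divergence(p: Dist, q: Dist, eps: float = 1e-12) -> float:
    """KL(p || q) for one position; eps guards against log of zero in q."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0.0)

def argmax(d: Dist) -> int:
    return max(range(len(d)), key=lambda k: d[k])

def feedback_decode(predict: Callable[[List[Optional[int]]], List[Dist]],
                    length: int,
                    commit_prob: float = 0.9,
                    kl_tol: float = 1e-3,
                    max_steps: int = 64) -> List[int]:
    """Hypothetical feedback loop: commit confident positions, keep refining the rest.

    `predict` receives the current state (a token id, or None for a position that
    is still soft) and returns one distribution per position. The loop stops early
    once the mean KL between consecutive predictions at open positions < kl_tol.
    """
    committed: List[Optional[int]] = [None] * length
    prev: Optional[List[Dist]] = None
    for _ in range(max_steps):
        dists = predict(committed)
        open_pos = [i for i in range(length) if committed[i] is None]
        if not open_pos:
            break
        # Early-stop check: have the open positions' beliefs stopped moving?
        if prev is not None:
            mean_kl = sum(kl_divergence(dists[i], prev[i]) for i in open_pos) / len(open_pos)
            if mean_kl < kl_tol:
                break
        # Commit positions whose top token is confident enough; others stay soft.
        for i in open_pos:
            best = argmax(dists[i])
            if dists[i][best] >= commit_prob:
                committed[i] = best
        prev = dists
    # Positions still open at the end take the argmax of a final prediction.
    final = predict(committed)
    return [c if c is not None else argmax(d) for c, d in zip(committed, final)]
```

With a toy `predict` that always returns `[[0.95, 0.05], [0.6, 0.4]]`, position 0 is committed on the first step, the second step sees zero KL change and stops, and position 1 falls back to its argmax, giving `[0, 0]`.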

Paper Structure

This paper contains 19 sections, 13 equations, 5 figures, 4 tables.

Figures (5)

  • Figure 1: Comparison between the existing decoding strategy and the proposed method. Different colours represent distinct tokens, while gradient colours indicate predicted token representations. Top: In the existing strategy, all [MASK] tokens share the same embedding and are repeatedly remasked if not selected. Bottom: In LRD, Phase 1 refines each [MASK] embedding, and Phase 2 progressively commits confident tokens while keeping uncertain ones soft for context-aware decoding.
  • Figure 2: KL divergence between step-wise predictive distributions and final decoded results for LLaDA-1.5 and Dream-Ins across benchmarks. The red vertical line marks where decoding begins after a fixed 20-step latent refinement.
  • Figure 3: Convergence ratios across latent refinement steps for LLaDA-1.5 and Dream-Ins on four benchmarks. Since computing the difference in KL divergence requires at least three consecutive steps, the curves are plotted starting from step 2.
  • Figure 4: Accuracy of Dream-Ins on four benchmarks under different maximum token proportions $r_f$, where $r_f = 0$ corresponds to no mixing.
  • Figure 5: Effect of top-$p$ mixing on Dream-Ins across four benchmarks. The purple curve shows the log fraction of tokens included in the mixture.
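The remark in Figure 3 that a KL difference needs three consecutive steps can be made concrete. In one plausible reading (an assumption; the exact definition behind Figure 3 is not stated here), a position counts as converged at step t when KL(p_t || p_{t-1}) and KL(p_{t-1} || p_{t-2}) differ by less than a tolerance, and the convergence ratio is the fraction of such positions. With 0-indexed steps, t = 2 is the first step where this is defined, which is why the curves start there. The function names and `tol` are hypothetical.

```python
import math
from typing import List, Sequence

Dist = Sequence[float]

def kl(p: Dist, q: Dist, eps: float = 1e-12) -> float:
    """KL(p || q) for one position; eps guards against log of zero in q."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0.0)

def convergence_ratio(history: List[List[Dist]], step: int, tol: float = 1e-3) -> float:
    """Fraction of positions whose step-to-step KL has stopped changing at `step`.

    history[t][i] is the predictive distribution of position i at refinement step t.
    Steps t-2, t-1 and t are needed, so the earliest valid step is 2.
    """
    if step < 2:
        raise ValueError("need three consecutive steps; earliest valid step is 2")
    now, prev, prev2 = history[step], history[step - 1], history[step - 2]
    converged = 0
    for i in range(len(now)):
        d_now = kl(now[i], prev[i])     # movement between steps t-1 and t
        d_prev = kl(prev[i], prev2[i])  # movement between steps t-2 and t-1
        if abs(d_now - d_prev) < tol:
            converged += 1
    return converged / len(now)
```

Here a position whose distribution has been constant for three steps counts as converged, while one that moved sharply between steps 0 and 1 and then stopped does not yet count at step 2, because its KL change is still large.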