Stopping Computation for Converged Tokens in Masked Diffusion-LM Decoding

Daisuke Oba, Danushka Bollegala, Masahiro Kaneko, Naoaki Okazaki

TL;DR

The work tackles the inefficiency of MDLM decoding, in which every token is reprocessed at every step at an $O(N^2 d)$ attention cost. It introduces SureLock, a convergence-based locking scheme that permanently freezes unmasked tokens once they have stabilized and caches their $K/V$ so that the remaining tokens can still attend to them. This reduces the per-step cost to $O(M N d)$, where $M$ is the number of unlocked positions, and gives a compute profile that decreases monotonically as sampling proceeds. A theoretical bound ties the local KL divergence at the lock step to the deviation in final token probabilities, which gives a principled justification for the locking rule. Empirically, SureLock reduces algorithmic FLOPs by 30–50% on LLaDA-8B-Instruct with comparable generation quality, and its savings are complementary to those of orthogonal acceleration methods, which makes it practical for longer-context diffusion decoding.

Abstract

Masked Diffusion Language Models generate sequences via iterative sampling that progressively unmasks tokens. However, they still recompute the attention and feed-forward blocks for every token position at every step -- even when many unmasked tokens are essentially fixed, resulting in substantial waste in compute. We propose SureLock: when the posterior at an unmasked position has stabilized across steps (our sure condition), we lock that position -- thereafter skipping its query projection and feed-forward sublayers -- while caching its attention keys and values so other positions can continue to attend to it. This reduces the dominant per-iteration computational cost from $O(N^2d)$ to $O(MNd)$ where $N$ is the sequence length, $M$ is the number of unlocked token positions, and $d$ is the model dimension. In practice, $M$ decreases as the iteration progresses, yielding substantial savings. On LLaDA-8B, SureLock reduces algorithmic FLOPs by 30--50% relative to the same sampler without locking, while maintaining comparable generation quality. We also provide a theoretical analysis to justify the design rationale of SureLock: monitoring only the local KL at the lock step suffices to bound the deviation in final token probabilities. Our code will be available at https://daioba.github.io/surelock.
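
The sure condition described above can be sketched in a few lines. This is a minimal illustration and not the authors' implementation: the function names `kl_divergence` and `update_locks`, the direction of the KL divergence, and the use of a single scalar threshold `eps` are all assumptions. The abstract only states that an unmasked position is locked once its posterior has stabilized across steps.

```python
import math

def kl_divergence(p, q, tiny=1e-12):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(
        pi * math.log((pi + tiny) / (qi + tiny))
        for pi, qi in zip(p, q)
        if pi > 0.0
    )

def update_locks(prev_probs, curr_probs, unmasked, locked, eps):
    """Return the set of locked positions after one sampling step.

    prev_probs, curr_probs: per-position token distributions at steps t-1 and t.
    unmasked: positions whose token has already been unmasked.
    locked: positions locked so far (locking is permanent, so they are kept).
    eps: KL threshold (assumed here to be one scalar for all positions).
    """
    new_locked = set(locked)
    for i in unmasked:
        if i in new_locked:
            continue
        # Assumed test: lock when the step-to-step KL falls below eps.
        if kl_divergence(curr_probs[i], prev_probs[i]) <= eps:
            new_locked.add(i)
    return new_locked

# Toy distributions over a 3-token vocabulary (illustrative values only).
prev = [[0.70, 0.20, 0.10], [0.98, 0.01, 0.01], [0.34, 0.33, 0.33]]
curr = [[0.95, 0.03, 0.02], [0.98, 0.01, 0.01], [0.34, 0.33, 0.33]]
locked = update_locks(prev, curr, unmasked={0, 1}, locked=set(), eps=5e-4)
```

In this toy run position 0 is unmasked but still changing, position 1 is unmasked and stable, and position 2 is stable but still masked, so only position 1 is locked. A locked position would then skip its query projection and feed-forward sublayers while its cached keys and values remain visible to the other positions.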

Paper Structure

This paper contains 54 sections, 1 theorem, 46 equations, 15 figures, 6 tables, and 1 algorithm.

Key Result

Theorem 1

Fix a position $i$ that is unlocked up to $t^\star$ and then locked (Alg. 1). Under (A1)--(A4), for any terminal time $T>t^\star$, the terminal log-probability error at position $i$ is at most $C_{\mathrm{tail}}\sqrt{D_{t^\star}^{(i)}}$. In particular, if the locking test enforces $D_{t^\star}^{(i)}\le \varepsilon$, then the terminal log-probability error is at most $\delta = C_{\mathrm{tail}}\sqrt{\varepsilon}$, so the closed-form threshold is $\varepsilon = (\delta / C_{\mathrm{tail}})^2$.
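
Inverting the relation $\delta = C_{\mathrm{tail}}\sqrt{\varepsilon}$ stated in the theorem gives the locking threshold for a chosen error budget. The sketch below only does that arithmetic. The helper names `lock_threshold` and `error_bound` are hypothetical, and the numeric values of `delta` and `c_tail` are illustrative; the constant $C_{\mathrm{tail}}$ depends on assumptions (A1)--(A4), which are not reproduced in this excerpt.

```python
import math

def lock_threshold(delta, c_tail):
    """Return eps such that c_tail * sqrt(eps) == delta."""
    if delta < 0 or c_tail <= 0:
        raise ValueError("delta must be non-negative and c_tail positive")
    return (delta / c_tail) ** 2

def error_bound(eps, c_tail):
    """Terminal log-probability error bound implied by a threshold eps."""
    return c_tail * math.sqrt(eps)

# Illustrative values only; neither comes from the paper.
eps = lock_threshold(delta=0.1, c_tail=2.0)
bound = error_bound(eps, c_tail=2.0)
```

Because the bound scales with $\sqrt{\varepsilon}$, halving the error budget $\delta$ requires a threshold four times smaller.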

Figures (15)

  • Figure 1: Conceptual figure of iterative sampling. (a) The normal sampler (Baseline) recomputes attention scores and FFN sublayers for every token position at every step, even after the marginal tokens have become unmasked. (b) SureLock permanently stops recomputation for positions once they are locked. Via cached $K/V$, other tokens still attend to locked tokens.
  • Figure 2: Step-wise FLOPs ratio. The ratio of step-wise algorithmic FLOPs, i.e., $\mathcal{F}^{t}_{\text{prop}}/\mathcal{F}^{t}_{\text{base}}$, consistently decreases as steps proceed, which explains the larger computational savings at later steps.
  • Figure 3: Throughput behavior with SureLock: (a) end-to-end TPS ratio across different ${N_{\text{gen}}}$ and batch sizes $B$; (b) per-step TPS ratio, which increases as sampling progresses.
  • Figure 4: Comparison of responses between Baseline and SureLock on LLaDA-8B-Instruct with $\varepsilon=5\times 10^{-4}$, ${N_{\text{gen}}}=128$, $S=128$. The question is sampled from MT-Bench (question ID 119).
  • Figure 5: Dynamics of step-wise KL divergence. Averaged step-wise KL divergence measured using LLaDA-8B-Instruct on MT-Bench prompts.
  • ...and 10 more figures
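
The decreasing step-wise ratio in Figure 2 follows from the cost model in the abstract. If only the dominant attention term is counted, a step with $M_t$ unlocked positions costs $O(M_t N d)$ against a baseline of $O(N^2 d)$, so the ratio reduces to $M_t/N$. The sketch below works through that simplification. The function names and the unlock schedule are made up for illustration and are not measurements from the paper, whose algorithmic FLOP accounting may include further terms.

```python
def attention_cost_ratio(num_unlocked, seq_len):
    """Per-step ratio (M_t * N * d) / (N^2 * d), which simplifies to M_t / N."""
    return num_unlocked / seq_len

def cumulative_saving(unlocked_per_step, seq_len):
    """Fraction of dominant-term cost saved over all steps versus no locking."""
    base = len(unlocked_per_step) * seq_len
    prop = sum(unlocked_per_step)
    return 1.0 - prop / base

# Hypothetical schedule: number of unlocked positions at each of 5 steps, N = 8.
schedule = [8, 8, 6, 4, 2]
ratios = [attention_cost_ratio(m, 8) for m in schedule]
saving = cumulative_saving(schedule, 8)
```

Since locking is permanent, $M_t$ never increases, so the per-step ratio under this simplified model is non-increasing in $t$.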

Theorems & Definitions (2)

  • Theorem 1: Locking error bound and closed-form threshold
  • proof