EDIT: Early Diffusion Inference Termination for dLLMs Based on Dynamics of Training Gradients
He-Yen Hsieh, Hong Wang, H. T. Kung
TL;DR
The paper tackles the high computational cost of diffusion-based LLM (dLLM) inference by introducing EDIT, a method that uses training-time optimization metadata to decide when diffusion denoising can stop early. Specifically, it uses AdamW-aggregated updates to LoRA-fine-tuned QKV projections captured during supervised fine-tuning. EDIT stores a compact representation of these learned reasoning pathways (a reasoning map) and scores current token activations against it via cosine similarity. It declares reasoning stable when the KL divergence between consecutive denoising steps falls below a threshold, computed over the tokens unmasked at both steps. The approach reduces denoising steps by 11.8% to 68.3% while preserving or improving accuracy in most settings across several reasoning benchmarks. It adds minimal storage overhead (about 1.5-2 MB) and requires no architectural changes. The work shows that training dynamics carry useful signals for inference decisions, opening avenues for dynamic compute and more efficient deployment of diffusion-based models. The method does, however, depend on access to optimization metadata and on task-specific tuning.
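The stopping rule summarized above can be sketched in plain Python. This is a minimal sketch under assumptions the text does not fix:

- per-token activations and the reasoning map are equal-length vectors;
- cosine alignment scores become a distribution via a softmax;
- the divergence is KL(current || previous), with both distributions renormalized over the positions unmasked at both steps.

The helper names (`alignment_distribution`, `kl_on_matched`, `early_exit_step`) and the `min_matched` guard are hypothetical. How EDIT actually builds the reasoning map from AdamW-aggregated LoRA updates is not modeled here; the map is simply taken as given.

```python
import math
from typing import Dict, List, Optional, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na > 0 and nb > 0 else 0.0


def alignment_distribution(acts: Dict[int, Vector], reasoning_map: Vector) -> Dict[int, float]:
    """Softmax over cosine alignment scores of the tokens unmasked at this step.

    `acts` maps token position -> activation vector and should contain only
    unmasked positions (hypothetical interface).
    """
    if not acts:
        return {}
    scores = {pos: cosine(v, reasoning_map) for pos, v in acts.items()}
    top = max(scores.values())  # subtract the max for numerical stability
    exps = {pos: math.exp(s - top) for pos, s in scores.items()}
    z = sum(exps.values())
    return {pos: e / z for pos, e in exps.items()}


def kl_on_matched(p: Dict[int, float], q: Dict[int, float], min_matched: int = 2) -> Optional[float]:
    """KL(p || q) over positions present in both, each renormalized on that support.

    Returns None if fewer than `min_matched` positions overlap. This is a
    hypothetical guard: with one shared token the renormalized KL is always 0.
    """
    common = sorted(set(p) & set(q))
    if len(common) < min_matched:
        return None
    ps = sum(p[i] for i in common)
    qs = sum(q[i] for i in common)
    return sum((p[i] / ps) * math.log((p[i] / ps) / (q[i] / qs)) for i in common)


def early_exit_step(per_step_acts: List[Dict[int, Vector]], reasoning_map: Vector,
                    threshold: float) -> Optional[int]:
    """Index of the first step whose distribution is within `threshold` (KL)
    of the previous step's, or None if the rule never fires."""
    prev: Optional[Dict[int, float]] = None
    for t, acts in enumerate(per_step_acts):
        curr = alignment_distribution(acts, reasoning_map)
        if prev is not None:
            kl = kl_on_matched(curr, prev)
            if kl is not None and kl < threshold:
                return t  # stable: skip the remaining denoising steps
        prev = curr
    return None
```

In a toy run where tokens unmask one at a time and the relative alignment of already-visible tokens stops changing, the rule fires at the first step where at least two matched tokens have the same renormalized distribution as before. Softmax preserves score ratios, so newly unmasked tokens do not by themselves perturb the matched-support KL. The threshold plays the role of the task-specific tuning knob mentioned above.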
Abstract
Diffusion-based large language models (dLLMs) refine token generations through iterative denoising, but answers often stabilize before all steps complete. We propose EDIT (Early Diffusion Inference Termination), an inference-time criterion that adaptively stops denoising once it detects sufficient reasoning stability relative to training-time reasoning. EDIT monitors the alignment between token activations and a reasoning map derived from AdamW-aggregated LoRA updates captured during supervised fine-tuning (SFT). During training, optimization dynamics generate rich metadata about parameter importance that is typically discarded upon model release; we preserve this information as a compact representation of learned reasoning pathways. During inference, alignment scores are converted into a distribution over the tokens already unmasked at the current denoising step. Convergence is declared when the KL divergence between the distributions at consecutive steps falls below a threshold, computed on the matched unmasked (visible) tokens. Across reasoning benchmarks, EDIT reduces diffusion steps by 11.8% to 68.3% while preserving or improving accuracy in most settings. The storage overhead is approximately 0.02% (about 1.5-2 MB for all QKV modules across 32 blocks of an 8 GB model). By exploiting training-gradient dynamics, our work opens a new research direction for reducing dLLM inference time and cost.
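As a quick consistency check on the stated overhead, assume decimal units (1 GB = 1000 MB). The 1.5-2 MB reasoning map relative to an 8 GB model then works out to:

```latex
\frac{1.5\ \text{MB}}{8000\ \text{MB}} = 0.0001875 \approx 0.019\%,
\qquad
\frac{2\ \text{MB}}{8000\ \text{MB}} = 0.00025 = 0.025\%
```

With binary units (8 GB = 8192 MB) the range becomes roughly 0.018% to 0.024%, so either convention is consistent with the quoted figure of approximately 0.02%.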
