
EDIT: Early Diffusion Inference Termination for dLLMs Based on Dynamics of Training Gradients

He-Yen Hsieh, Hong Wang, H. T. Kung

TL;DR

The paper tackles the high computational cost of diffusion-based LLM inference by introducing EDIT, a method that uses training-time optimization metadata (specifically, the evolution of AdamW updates to LoRA-finetuned QKV projections) to decide when diffusion denoising can stop early. EDIT preserves a compact representation of learned reasoning pathways, compares current token activations against this map via cosine similarity, and uses KL divergence over the matched token support to detect when reasoning has stabilized. The approach yields substantial inference speedups (11.8%–68.3% fewer denoising steps) while maintaining or improving accuracy on several reasoning benchmarks, with minimal storage overhead (~1.5–2 MB) and no architectural changes. The work demonstrates that training dynamics carry useful signals for inference decisions, opening avenues for dynamic compute and more efficient deployment of diffusion-based models, although it depends on access to optimization metadata and on task-specific tuning.
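A minimal Python sketch of the per-step signal, under stated assumptions: the reasoning map is collapsed into a single hypothetical direction `reasoning_vec`, each unmasked token's alignment score is its cosine similarity with that vector, scores are turned into a distribution with a softmax (the paper's exact normalization may differ), and the divergence is KL(P_t ‖ P_{t-1}) after renormalizing both distributions over the positions unmasked at both steps. All function names are illustrative, not the authors' code.

```python
import math
from typing import Dict, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero norm."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu > 0 and nv > 0 else 0.0


def alignment_distribution(
    activations: Dict[int, Sequence[float]],
    reasoning_vec: Sequence[float],
    temperature: float = 1.0,
) -> Dict[int, float]:
    """Softmax over cosine alignment scores of the currently unmasked positions.

    `activations` maps token position -> activation vector and should contain
    only unmasked positions. The softmax and `temperature` are assumptions.
    """
    if not activations:
        return {}
    scores = {pos: cosine(a, reasoning_vec) / temperature for pos, a in activations.items()}
    top = max(scores.values())  # subtract the max for numerical stability
    exps = {pos: math.exp(s - top) for pos, s in scores.items()}
    z = sum(exps.values())
    return {pos: e / z for pos, e in exps.items()}


def matched_kl(p_curr: Dict[int, float], p_prev: Dict[int, float], eps: float = 1e-12) -> float:
    """KL(p_curr || p_prev) restricted to positions unmasked at both steps.

    Both distributions are renormalized over the shared support first.
    Returns inf when there is no overlap, so it can never signal convergence.
    """
    support = set(p_curr) & set(p_prev)
    if not support:
        return float("inf")
    zc = sum(p_curr[i] for i in support)
    zp = sum(p_prev[i] for i in support)
    kl = 0.0
    for i in support:
        pc, pp = p_curr[i] / zc, p_prev[i] / zp
        kl += pc * math.log((pc + eps) / (pp + eps))
    return kl
```

For example, if the same two positions are unmasked at consecutive steps with unchanged activations, `matched_kl` returns (numerically) zero even when a third position is newly unmasked, because newly revealed tokens fall outside the matched support.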

Abstract

Diffusion-based large language models (dLLMs) refine token generations through iterative denoising, but answers often stabilize before all steps complete. We propose EDIT (Early Diffusion Inference Termination), an inference-time criterion that adaptively stops denoising once sufficient reasoning stability relative to training-time reasoning is detected. EDIT monitors the alignment between token activations and a reasoning map derived from AdamW-aggregated LoRA updates captured during supervised fine-tuning (SFT). During training, optimization dynamics generate rich metadata about parameter importance that prior methods typically discard upon model release. We preserve this information as a compact representation of learned reasoning pathways. During inference, alignment scores are converted to a distribution over the tokens already unmasked at the current denoising step, and convergence is detected when the KL divergence between consecutive steps, computed on the matched unmasked (visible) tokens, falls below a threshold. Across reasoning benchmarks, EDIT reduces diffusion steps by 11.8% to 68.3% while preserving or improving accuracy in most settings, with approximately 0.02% storage overhead (about 1.5–2 MB for all QKV modules across 32 blocks in an 8 GB model). By utilizing training-gradient dynamics, our work opens a new research direction for reducing dLLM inference time and cost.
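As a consistency check on the storage figures, assuming decimal units (1 GB = 10^9 bytes):

$$0.02\% \times 8\ \text{GB} = (2 \times 10^{-4}) \times (8 \times 10^{9}\ \text{B}) = 1.6 \times 10^{6}\ \text{B} \approx 1.6\ \text{MB}$$

which lies inside the quoted 1.5–2 MB range. With binary units (8 GiB = 8 × 2^30 bytes), the same 0.02% is about 1.72 × 10^6 bytes, also inside that range.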

Paper Structure

This paper contains 42 sections, 9 theorems, 30 equations, 11 figures, 4 tables, 1 algorithm.

Key Result

Lemma 1 (Run-length KL implies multi-step TV bound)

If $D_{t-\Omega+1}, \ldots, D_t \leq \delta$, then
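The lemma's hypothesis is a run-length condition: the last $\Omega$ consecutive KL values $D$ all stay at or below $\delta$. Its conclusion is truncated in this summary, so the sketch below implements only that hypothesis as a stopping test. The function name and the 1-based indexing over the supplied KL sequence are assumptions, not the paper's interface.

```python
from typing import Iterable, Optional


def first_stop_step(kl_values: Iterable[float], delta: float, omega: int) -> Optional[int]:
    """Return the first index t (1-based) with D_{t-omega+1}, ..., D_t <= delta.

    `kl_values` yields D_1, D_2, ... in denoising order. Returns None if the
    run-length condition is never met, i.e. all denoising steps are run.
    """
    if omega < 1:
        raise ValueError("omega must be >= 1")
    run = 0  # length of the current run of consecutive KL values <= delta
    for t, d in enumerate(kl_values, start=1):
        run = run + 1 if d <= delta else 0  # any spike above delta resets the run
        if run >= omega:
            return t
    return None
```

With $\delta = 0.05$ and $\Omega = 3$, the sequence 0.5, 0.01, 0.02, 0.3, 0.01, 0.01, 0.01 stops at step 7: the short run at steps 2–3 is reset by the spike at step 4. Requiring a run rather than a single small value guards against stopping on a momentary lull.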

Figures (11)

  • Figure 1: Gradient-based analysis of training–inference alignment on GPQA (seq. 128, 2nd block). Root mean square (RMS) pseudo-gradients $\tilde{G}_{t,B}$ across steps are compared with the SFT gradient mean (dashed) and variance band (shaded). The convergence point (yellow $\blacktriangledown$) occurs at step 19, after which pseudo-gradients stabilize near the SFT mean, indicating that $\sim$20 steps per block preserve fidelity while reducing computation (see the efficiency table: 40.3 steps for two blocks).
  • Figure 2: Performance breakdown across GPQA subdomains comparing EDIT (red) with baseline SFT (green). EDIT shows particularly strong improvements in Molecular Biology and Astrophysics, where reasoning patterns are more structured. The domain-specific variation validates that training metadata captures specialized reasoning pathways.
  • Figure 3: Task-specific parameter activation patterns revealed by AdamW evolution. Different GPQA subdomains (Astrophysics vs. Molecular Biology) engage distinct parameter subsets in the LoRA-B matrix of the Query projection (transformer.block.31). The 3D visualization shows how parameter importance varies across tasks, demonstrating that training metadata captures specialized reasoning pathways.
  • Figure 4: Visualization of LoRA-B parameter updates at training step 105. A $4096\times128$ LoRA projection produces $524{,}288$ parameters, reshaped into a $256\times256\times8$ grid with the Z-axis showing normalized AdamW update magnitudes (scaled 0–255). Pronounced peaks indicate parameters critical for reasoning tasks, demonstrating that optimization dynamics create clear importance signatures.
  • Figure 5: Visualization of LoRA-A parameter updates at training step 105. A $128\times4096$ LoRA projection produces $524{,}288$ parameters, reshaped into a $256\times256\times8$ grid with the Z-axis showing normalized AdamW update magnitudes (scaled 0–255). Pronounced peaks indicate parameters critical for reasoning tasks, demonstrating that optimization dynamics create clear importance signatures.
  • ...and 6 more figures
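The reshaping behind Figures 4 and 5 can be sketched as follows, assuming the plotted values are absolute per-parameter update magnitudes, min-max scaled to 0–255 and laid out in row-major order. The captions do not specify the exact scaling or ordering, and the helper names are hypothetical. Because 4096 × 128 = 128 × 4096 = 524,288 = 256 × 256 × 8, the same default grid fits both the LoRA-A and LoRA-B factors.

```python
from typing import List, Sequence, Tuple


def normalize_to_uint8(values: Sequence[float]) -> List[int]:
    """Min-max scale absolute update magnitudes to integers in [0, 255]."""
    mags = [abs(v) for v in values]
    lo, hi = min(mags), max(mags)
    if hi == lo:
        return [0] * len(mags)  # flat map: no parameter stands out
    return [round(255 * (m - lo) / (hi - lo)) for m in mags]


def reshape_3d(flat: Sequence[int], d0: int, d1: int, d2: int) -> List[List[List[int]]]:
    """Row-major reshape of a flat list into a d0 x d1 x d2 nested list."""
    if len(flat) != d0 * d1 * d2:
        raise ValueError(f"cannot reshape {len(flat)} values into {d0}x{d1}x{d2}")
    return [
        [list(flat[(i * d1 + j) * d2:(i * d1 + j + 1) * d2]) for j in range(d1)]
        for i in range(d0)
    ]


def lora_update_grid(
    update_matrix: Sequence[Sequence[float]],
    grid: Tuple[int, int, int] = (256, 256, 8),
) -> List[List[List[int]]]:
    """Flatten a LoRA factor's per-parameter updates (row-major), scale, and reshape."""
    flat = [x for row in update_matrix for x in row]
    return reshape_3d(normalize_to_uint8(flat), *grid)
```

Min-max scaling keeps the largest update at 255 regardless of the optimizer's learning rate, so peaks remain comparable across training steps; a different choice (for example, scaling by a fixed global maximum) would change how the figures look.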

Theorems & Definitions (18)

  • lemma 1: Run-length KL implies multi-step TV bound
  • proof
  • theorem 1: Local argmax invariance certificate
  • proof
  • theorem 3: Tail movement bound and global argmax preservation
  • proof
  • theorem 4: Stability of Lipschitz functionals
  • proof
  • corollary 1: PAC-style guarantee for the final answer
  • proof
  • ...and 8 more