A2D: Any-Order, Any-Step Safety Alignment for Diffusion Language Models
Wonje Jeung, Sangyeon Yoon, Yoonjun Cho, Dongjae Jeon, Sangwoo Shin, Hyesoo Hong, Albert No
TL;DR
This work addresses the safety vulnerabilities that arise in diffusion large language models (dLLMs) from any-order, any-step decoding. It introduces A2D, a token-level safety alignment method that trains dLLMs to emit an [EOS] refusal signal at masked positions in place of harmful spans, preventing unsafe content across arbitrary decoding trajectories. Trained on a dedicated Harmful/Retain dataset with uniform masking, A2D achieves deep safety that is robust to both decoding order and decoding step. It suppresses attacks such as DIJA while preserving general capabilities, and it enables real-time safety monitoring, where a high [EOS] probability triggers early rejection. Across multiple dLLMs and benchmark suites, A2D substantially reduces harmful outputs (e.g., the DIJA attack success rate (ASR) drops to near zero) with only a modest alignment tax, performs well under extreme safety tests, and keeps inference efficient. The approach may also generalize beyond diffusion architectures, suggesting a broader framework for token-level safety in generative models.
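To make the training signal concrete, the sketch below builds one Harmful or Retain training pair. It assumes that "uniform masking" means drawing a masking ratio t uniformly from (0, 1] and masking each response token independently with probability t, as in standard masked-diffusion training. The helper name `build_a2d_example` and the `[MASK]` string are hypothetical, and this is an illustration under those assumptions, not the authors' implementation.

```python
import random
from typing import List, Optional, Tuple

MASK = "[MASK]"  # hypothetical mask-token string
EOS = "[EOS]"    # refusal signal described above

def build_a2d_example(
    response_tokens: List[str],
    is_harmful: bool,
    rng: random.Random,
) -> Tuple[List[str], List[Optional[str]]]:
    """Hypothetical A2D-style training pair under uniform masking.

    Assumption: a masking ratio t is drawn uniformly from (0, 1] and each
    response token is masked independently with probability t. Only masked
    positions receive a target; unmasked positions get None (no loss).
    """
    t = 1.0 - rng.random()  # rng.random() is in [0, 1), so t is in (0, 1]
    inputs: List[str] = []
    targets: List[Optional[str]] = []
    for tok in response_tokens:
        if rng.random() < t:
            inputs.append(MASK)
            # Harmful sample: learn to predict [EOS] at every masked position.
            # Retain sample: learn to reconstruct the original token.
            targets.append(EOS if is_harmful else tok)
        else:
            inputs.append(tok)
            targets.append(None)  # visible token, excluded from the loss
    return inputs, targets

rng = random.Random(0)
x, y = build_a2d_example(["h1", "h2", "h3", "h4"], is_harmful=True, rng=rng)
# In a harmful sample, every supervised (masked) position targets [EOS].
```

Because the masked set is random at every step, the model sees harmful content revealed in many different orders and proportions, which is the intuition behind robustness to arbitrary decoding trajectories. The Retain branch uses the ordinary reconstruction target so that benign capabilities are kept.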
Abstract
Diffusion large language models (dLLMs) enable any-order generation, but this flexibility enlarges the attack surface: harmful spans may appear at arbitrary positions, and template-based prefilling attacks such as DIJA bypass response-level refusals. We introduce A2D (Any-Order, Any-Step Defense), a token-level alignment method that trains dLLMs to emit an [EOS] refusal signal whenever harmful content arises. By aligning safety directly at the token level under randomized masking, A2D achieves robustness to both arbitrary decoding orders and any-step prefilling attacks under various conditions. It also enables real-time monitoring: a dLLM may begin a response but automatically terminates if an unsafe continuation emerges. On safety benchmarks, A2D consistently prevents the generation of harmful outputs, slashing DIJA success rates from over 80% to near zero (1.3% on LLaDA-8B-Instruct, 0.0% on Dream-v0-Instruct-7B), and thresholding the [EOS] probability allows early rejection, yielding up to 19.3× faster safe termination.
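The early-rejection idea can be sketched as a simple monitor over the denoising trajectory. Here we assume the monitor receives, at each step, the model's [EOS] probabilities at the positions that are still masked, and rejects as soon as the largest one reaches a threshold. The function name `first_rejection_step`, the max aggregation, and the default threshold of 0.9 are illustrative assumptions, not details taken from the paper.

```python
from typing import Iterable, Optional, Sequence

def first_rejection_step(
    eos_probs_per_step: Iterable[Sequence[float]],
    threshold: float = 0.9,
) -> Optional[int]:
    """Return the first denoising step whose [EOS] signal crosses the threshold.

    Each element holds [EOS] probabilities at the still-masked positions for
    one step. Using max() and a 0.9 default are hypothetical choices.
    Returns None if the trajectory never triggers a rejection.
    """
    for step, probs in enumerate(eos_probs_per_step):
        if len(probs) > 0 and max(probs) >= threshold:
            return step  # stop here; later steps are never consumed
    return None

unsafe = [[0.01, 0.02], [0.05, 0.30], [0.95, 0.40], [0.99, 0.99]]
safe = [[0.01, 0.02], [0.03, 0.01]]
print(first_rejection_step(unsafe))  # prints 2
print(first_rejection_step(safe))    # prints None
```

Because the loop consumes its input lazily, passing a generator that runs one denoising step per iteration means no step after the rejection is ever computed. If rejection fires at step k of T, only k + 1 steps run, which is the source of the speedup that early rejection offers over decoding the full response.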
