
A2D: Any-Order, Any-Step Safety Alignment for Diffusion Language Models

Wonje Jeung, Sangyeon Yoon, Yoonjun Cho, Dongjae Jeon, Sangwoo Shin, Hyesoo Hong, Albert No

TL;DR

This work addresses the safety vulnerabilities of diffusion large language models (dLLMs) that arise from any-order, any-step decoding. It introduces A2D, a token-level safety alignment method that trains dLLMs to emit an [EOS] refusal signal at masked positions in place of harmful spans, so that unsafe content is blocked along arbitrary decoding trajectories. Using a specialized Harmful/Retain dataset and uniform masking during training, A2D achieves deep, order- and step-robust safety that suppresses attacks like DIJA while preserving general capabilities. It also enables real-time safety monitoring: the [EOS] probability can be thresholded to trigger early rejection. Empirical results across multiple dLLMs and benchmark suites show substantial reductions in harmful outputs (e.g., DIJA ASR dropping to near zero) with only a modest alignment tax, along with favorable performance in extreme safety tests and efficient inference. The approach may also generalize beyond diffusion architectures, suggesting a broader framework for token-level safety in generative models.

Abstract

Diffusion large language models (dLLMs) enable any-order generation, but this flexibility enlarges the attack surface: harmful spans may appear at arbitrary positions, and template-based prefilling attacks such as DIJA bypass response-level refusals. We introduce A2D (Any-Order, Any-Step Defense), a token-level alignment method that aligns dLLMs to emit an [EOS] refusal signal whenever harmful content arises. By aligning safety directly at the token level under randomized masking, A2D achieves robustness to both any-decoding-order and any-step prefilling attacks under various conditions. It also enables real-time monitoring: dLLMs may begin a response but automatically terminate if an unsafe continuation emerges. On safety benchmarks, A2D consistently prevents the generation of harmful outputs, slashing DIJA success rates from over 80% to near-zero (1.3% on LLaDA-8B-Instruct, 0.0% on Dream-v0-Instruct-7B), and thresholded [EOS] probabilities allow early rejection, yielding up to 19.3x faster safe termination.
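
The token-level idea in the abstract can be illustrated with a minimal sketch. This is not the authors' implementation: the function names `build_example` and `should_reject`, the per-token `harmful` flags, the rule of checking the maximum [EOS] probability over monitored positions, and the 0.5 threshold are all assumptions made for illustration. The sketch shows one plausible reading: at masked positions inside a harmful span the training target becomes [EOS], visible tokens are left unchanged, and at inference a thresholded [EOS] probability triggers early rejection.

```python
import random

MASK, EOS = "[MASK]", "[EOS]"


def build_example(tokens, harmful, mask_ratio, rng):
    """Build (noisy_input, targets) for one response (hypothetical helper).

    tokens     : list of response tokens
    harmful    : list of bools, True where the token belongs to a harmful span
                 (assumed annotation format)
    mask_ratio : probability of masking each position independently
    """
    noisy, targets = [], []
    for tok, is_harmful in zip(tokens, harmful):
        if rng.random() < mask_ratio:
            noisy.append(MASK)
            # Masked harmful positions are supervised toward [EOS];
            # masked benign positions keep their original token as target.
            targets.append(EOS if is_harmful else tok)
        else:
            noisy.append(tok)     # visible token stays as context
            targets.append(None)  # no loss on unmasked positions
    return noisy, targets


def should_reject(eos_probs, threshold=0.5):
    """Reject early if any monitored [EOS] probability crosses the threshold.

    Both the max-over-positions rule and the default threshold are assumptions.
    """
    return max(eos_probs, default=0.0) >= threshold


if __name__ == "__main__":
    rng = random.Random(0)
    ratio = rng.random()  # masking ratio drawn uniformly from [0, 1)
    tokens = ["Sure", ",", "step", "one", "is"]
    harmful = [False, False, True, True, True]
    print(build_example(tokens, harmful, ratio, rng))
    print(should_reject([0.02, 0.91, 0.40]))
```

Drawing the masking ratio uniformly means the model sees harmful responses at every level of partial completion, which is what would make the refusal signal independent of decoding order and step in this reading.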

Paper Structure

This paper contains 57 sections, 8 equations, 12 figures, 14 tables, 1 algorithm.

Figures (12)

  • Figure 1: Average attack success rates on Zeroshot, PAIR, ReNeLLM, Prefilling, and DIJA, evaluated on three instruction-tuned dLLMs. A2D consistently achieves the lowest attack success rate among the compared methods.
  • Figure 2: Overview of A2D for aligning dLLMs. Response-level methods supervise refusals only at the level of full responses, while A2D applies token-level alignment by replacing harmful spans with [EOS] tokens, enabling the model to reject unsafe content under any decoding order and at any step. A2D prevents template-based attacks from producing harmful outputs, whereas response-level alignment fails under the same setting.
  • Figure 3: Per-token KL divergence between aligned and base dLLMs. Aligned models (LLaDA-Instruct, LLaDA-1.5) vs. Base model (LLaDA-Base) on Harmful BeaverTails under three decoding strategies. All results are averaged over 150 harmful prompts from BeaverTails, with shaded regions indicating standard deviation.
  • Figure 4: Per-token KL divergence between A2D-aligned and base dLLMs. Aligned models (LLaDA-1.5, LLaDA-1.5-A2D) vs. Base model (LLaDA-Base) on Harmful BeaverTails under three sampling strategies. LLaDA-1.5-A2D refers to LLaDA-1.5 further aligned with A2D for safety. All results are averaged over 150 harmful prompts from BeaverTails, with shaded regions indicating standard deviation.
  • Figure 5: Attack success rates on extreme conditions for three instruction-tuned dLLMs across four alignment methods.
  • ...and 7 more figures