d2: Improved Techniques for Training Reasoning Diffusion Language Models
Guanghan Wang, Yair Schiff, Gilad Turok, Volodymyr Kuleshov
TL;DR
This paper introduces d2, a principled reinforcement learning (RL) framework for diffusion language models (DLMs) built on policy gradients that use trajectory likelihoods. It presents two specialized estimators, d2-StepMerge and d2-AnyOrder, that efficiently estimate trajectory likelihoods under masked diffusion. It also analyzes any-order causality as a crucial property for diffusion-based reasoning. Empirically, d2 achieves state-of-the-art reasoning results among DLMs on Sudoku, Countdown, GSM8K, and MATH500 without supervised chain-of-thought fine-tuning, and it shows strong toxicity steering in a red-teaming setup. The work makes RL-only post-training of DLMs more practical and provides theoretical guarantees on the accuracy and applicability of the estimators. Overall, d2 offers a scalable, theoretically grounded path to stronger reasoning in DLMs and clarifies the role of any-order decoding in this setting.
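The central object here is the likelihood of a sampling trajectory, and a small sketch can make it concrete. The sketch assumes a trajectory is recorded as a list of denoising steps. Each step maps positions to the tokens unmasked at that step, and every unmasked token is scored against the partially masked sequence that existed before its step. The model `toy_logprobs`, the `MASK` sentinel, and the mean-reward baseline are hypothetical stand-ins, not the paper's model or estimators. The sketch also assumes that the choice of which positions to unmask does not depend on the model parameters.

```python
import math
from typing import Dict, List, Optional, Sequence

MASK = None  # hypothetical sentinel for a not-yet-decoded position


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def toy_logprobs(theta: Sequence[float], state: Sequence[Optional[int]], pos: int) -> List[float]:
    """Hypothetical stand-in for a masked DLM: log-probs for one position given a partial sequence."""
    n_revealed = sum(tok is not MASK for tok in state)
    bonus = (n_revealed + pos) % len(theta)  # crude, deterministic dependence on context
    logits = [t + (1.0 if v == bonus else 0.0) for v, t in enumerate(theta)]
    return log_softmax(logits)


def trajectory_log_likelihood(theta: Sequence[float], length: int,
                              trajectory: List[Dict[int, int]]) -> float:
    """Sum of log-probs of the tokens unmasked at each step, conditioned on the state before that step."""
    state: List[Optional[int]] = [MASK] * length
    total = 0.0
    for step in trajectory:
        for pos, tok in step.items():
            assert state[pos] is MASK, "each position is unmasked exactly once"
            total += toy_logprobs(theta, state, pos)[tok]
        # Reveal only after scoring, so all tokens in one step share the same context.
        for pos, tok in step.items():
            state[pos] = tok
    return total


def reinforce_surrogate(theta: Sequence[float], length: int,
                        trajectories: List[List[Dict[int, int]]],
                        rewards: List[float]) -> float:
    """REINFORCE-style surrogate loss with a mean-reward baseline (an assumed choice)."""
    baseline = sum(rewards) / len(rewards)
    terms = [-(r - baseline) * trajectory_log_likelihood(theta, length, traj)
             for traj, r in zip(trajectories, rewards)]
    return sum(terms) / len(terms)
```

Differentiating `reinforce_surrogate` with respect to the model parameters in an autodiff framework yields a REINFORCE-style policy-gradient estimate. With a real DLM, one forward pass returns log-probabilities for all positions. The exact computation therefore needs one pass per denoising step, because each step sees a different context. The paper's d2-StepMerge and d2-AnyOrder estimators target this trajectory likelihood more efficiently.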
Abstract
While diffusion language models (DLMs) have achieved competitive performance in text generation, improving their reasoning ability with reinforcement learning remains an active research area. Here, we introduce d2, a reasoning framework tailored for masked DLMs. Central to our framework is a new policy gradient algorithm that relies on properties of masking to accurately estimate the likelihoods of sampling trajectories. Our estimators trade off computation for approximation accuracy in an analytically tractable manner, and are particularly effective for DLMs that support any-order likelihood estimation. We characterize and study this property in popular DLMs and show that it is key for efficient diffusion-based reasoning. Empirically, d2 significantly improves over previous diffusion reasoning frameworks using only RL (without relying on supervised fine-tuning), and sets new state-of-the-art results for DLMs on logical reasoning tasks (Countdown and Sudoku) and math reasoning benchmarks (GSM8K and MATH500).
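The following sketch illustrates the trade-off between computation and approximation accuracy. It shows one simple, hypothetical way to merge denoising steps: consecutive groups of `k` steps are scored with a single forward pass on the sequence state at the start of the group. This is an illustrative assumption, not a description of d2-StepMerge or d2-AnyOrder. The names `merged_scoring_plan` and `MASK`, and the trajectory format (a list of position-to-token maps, one per step), are also hypothetical. The function only builds the scoring plan, so it needs no model.

```python
from typing import Dict, List, Optional, Tuple

MASK = None  # hypothetical sentinel for a masked position

Step = Dict[int, int]  # position -> token revealed at one denoising step


def merged_scoring_plan(length: int, trajectory: List[Step], k: int
                        ) -> List[Tuple[List[Optional[int]], Step]]:
    """Group k consecutive steps; each group becomes one (context, tokens-to-score) pair.

    Hypothetical approximation: every token in a group is scored against the state at
    the start of the group. k = 1 reproduces the exact per-step contexts.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    state: List[Optional[int]] = [MASK] * length
    plan = []
    for start in range(0, len(trajectory), k):
        block: Step = {}
        for step in trajectory[start:start + k]:
            block.update(step)
        plan.append((list(state), block))  # snapshot of the context seen by this pass
        for pos, tok in block.items():
            state[pos] = tok
    return plan
```

With `k = 1`, each step is scored against its true context, which matches the exact computation. Larger `k` reduces the number of forward passes to ceil(T/k) for T steps. The cost is that later tokens in a group are scored against a more heavily masked context than the one they were actually sampled from.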
