Principled and Tractable RL for Reasoning with Diffusion Language Models
Anthony Zhan
TL;DR
This paper tackles the challenge of applying reinforcement learning (RL) post-training to diffusion language models (dLLMs), where standard policy-gradient methods are not directly compatible with the diffusion denoising process. It introduces Amortized GRPO (AGRPO), a principled on-policy RL algorithm that computes an unbiased policy-gradient estimate for dLLMs via Monte Carlo sampling over timesteps, avoiding biased one-step approximations. AGRPO also incorporates low-discrepancy sampling and practical design choices (caching, k-tuning) that make online RL feasible for diffusion models. Empirically, AGRPO improves accuracy on reasoning benchmarks (GSM8K, MATH, Countdown) relative to both the base LLaDA-8B-Instruct model and prior dLLM RL methods such as diffu-GRPO. It also expands the compute/quality frontier by achieving higher accuracy with fewer sampling steps, which matters for efficient reasoning with dLLMs.
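To make the amortization idea concrete, here is a toy Python sketch (not the authors' code). It assumes that the trajectory-level quantity whose gradient is needed decomposes as a sum of per-timestep terms over T denoising steps, e.g. the log-probabilities of the tokens unmasked at each step. `per_step` is a hypothetical stand-in for a model forward pass. Sampling k timesteps and rescaling yields an unbiased estimate of the full sum at a fraction of the cost. Stratified timestep sampling is used here as one generic low-discrepancy scheme, and it may differ from the scheme AGRPO actually uses.

```python
import random
import statistics
from typing import Callable

# Hypothetical stand-in for a per-timestep term, e.g. the log-probability of the
# tokens unmasked at denoising step t. In a real dLLM each call is a forward pass.
PerStep = Callable[[int], float]


def exact_sum(per_step: PerStep, T: int) -> float:
    """Reference value: evaluate all T denoising steps (T forward passes)."""
    return sum(per_step(t) for t in range(T))


def uniform_mc_estimate(per_step: PerStep, T: int, k: int, rng: random.Random) -> float:
    """Sample k distinct timesteps uniformly and rescale by T/k (unbiased)."""
    assert 1 <= k <= T
    ts = rng.sample(range(T), k)
    return (T / k) * sum(per_step(t) for t in ts)


def stratified_estimate(per_step: PerStep, T: int, k: int, rng: random.Random) -> float:
    """Low-discrepancy variant: split [0, T) into k contiguous strata, draw one
    timestep per stratum, and weight it by the stratum size (still unbiased)."""
    assert 1 <= k <= T
    total = 0.0
    for j in range(k):
        lo, hi = j * T // k, (j + 1) * T // k  # non-empty because k <= T
        t = rng.randrange(lo, hi)
        total += (hi - lo) * per_step(t)
    return total


if __name__ == "__main__":
    rng = random.Random(0)
    T, k = 100, 4                          # only k of T steps are evaluated per estimate
    per_step = lambda t: -0.05 * (t + 1)   # toy smooth per-step term (hypothetical)
    mc = [uniform_mc_estimate(per_step, T, k, rng) for _ in range(2000)]
    st = [stratified_estimate(per_step, T, k, rng) for _ in range(2000)]
    print("exact:", exact_sum(per_step, T))
    print("uniform MC mean/stdev:", statistics.mean(mc), statistics.stdev(mc))
    print("stratified mean/stdev:", statistics.mean(st), statistics.stdev(st))
```

Both estimators evaluate only k of the T steps. The stratified one spreads its samples across the whole trajectory, which typically lowers variance when the per-step terms vary smoothly with t.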
Abstract
Diffusion large language models (dLLMs) are a new paradigm of non-autoregressive language models that are trained to predict multiple tokens in parallel and generate text via iterative unmasking. Recent work has successfully pretrained dLLMs to parity with autoregressive LLMs at the 8B scale, but dLLMs have yet to benefit from modern post-training techniques, e.g., reinforcement learning (RL), that have proven effective for autoregressive models. Crucially, algorithms designed for traditional LLMs are not directly compatible with diffusion frameworks due to inherent differences in modeling assumptions. Moreover, existing attempts at dLLM post-training with RL rely on heuristic objectives with no theoretical grounding. In this work, we present Amortized Group Relative Policy Optimization (AGRPO), a principled on-policy RL algorithm designed specifically for dLLMs. AGRPO uses Monte Carlo sampling to compute an unbiased policy gradient estimate, making it the first tractable, faithful adaptation of policy gradient methods for dLLMs. We demonstrate AGRPO's effectiveness on several math and reasoning tasks, a common setting for RL with LLMs. Over the baseline LLaDA-8B-Instruct model, it achieves up to a +7.6% absolute gain on GSM8K and 3.8x the performance on the Countdown task, and it achieves 1.3x performance gains over comparable RL methods such as diffu-GRPO. Furthermore, these gains persist across different numbers of sampling steps at inference time, yielding better tradeoffs between compute and performance. Our results demonstrate that online RL algorithms can be extended to diffusion LLMs in principled ways while maintaining both theoretical soundness and practical effectiveness.
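For readers less familiar with the GRPO side of AGRPO, the sketch below shows the standard group-relative advantage and PPO-style clipped surrogate as GRPO is commonly described. It is illustrative only and is not the authors' implementation. The function names, the population standard deviation, the per-completion (rather than per-token) log-likelihoods, and the binary correctness reward are assumptions. KL regularization is omitted. For a dLLM, `logp_new` and `logp_old` would be sequence log-likelihoods, which AGRPO approximates via Monte Carlo sampling over timesteps.

```python
import math
from typing import List, Sequence


def group_relative_advantages(rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Score each completion against the other completions for the same prompt.

    Uses the population standard deviation (an assumption; implementations vary).
    If every reward in the group is equal, all advantages are 0.
    """
    n = len(rewards)
    mean = sum(rewards) / n
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / n)
    return [(r - mean) / (std + eps) for r in rewards]


def clipped_surrogate(
    logp_new: Sequence[float],
    logp_old: Sequence[float],
    advantages: Sequence[float],
    clip_eps: float = 0.2,
) -> float:
    """PPO-style clipped objective, averaged over the group (to be maximized).

    logp_new / logp_old are hypothetical per-completion log-likelihoods under the
    current and sampling policies. KL regularization is omitted for brevity.
    """
    total = 0.0
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(lp_new - lp_old)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        total += min(ratio * adv, clipped_ratio * adv)
    return total / len(advantages)


# Usage: 4 sampled answers to one prompt with binary correctness rewards (hypothetical).
rewards = [1.0, 0.0, 0.0, 1.0]
adv = group_relative_advantages(rewards)
logps = [-12.0, -15.0, -14.0, -11.0]
loss = -clipped_surrogate(logps, logps, adv)
```

With identical new and old log-likelihoods (fully on-policy), every ratio is 1 and the objective reduces to the mean advantage, which is 0 by construction here. The gradient still flows through `logp_new` in a real autodiff setting.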
